formed.integrations.flax.distributors

Distributed computing abstractions for Flax models.

This module provides abstractions for distributed training across multiple devices, supporting both single-device and data-parallel training strategies.

Key Components
  • BaseDistributor: Abstract interface for device distribution strategies
  • SingleDeviceDistributor: No-op distributor for single-device training
  • DataParallelDistributor: Data-parallel training using JAX pmap
Features
  • Transparent device sharding and replication
  • JIT compilation with device-specific optimizations
  • Reduction operations (mean, sum) across devices
  • Compatible with FlaxTrainer

Examples:

>>> from formed.integrations.flax import DataParallelDistributor
>>> import jax
>>>
>>> # Create data-parallel distributor for all available devices
>>> distributor = DataParallelDistributor(axis_name="batch")
>>>
>>> # Shard batch across devices
>>> sharded_batch = distributor.shard(batch)
>>>
>>> # Map function across devices
>>> train_step = distributor.map(training_function)
>>> outputs = train_step(sharded_batch, state)

BaseDistributor

Bases: Registrable, ABC, Generic[ModelInputT]

Abstract base class for device distribution strategies.

BaseDistributor defines the interface for distributing computations across devices in a JAX/Flax training pipeline. It handles data sharding, replication, and reduction operations.

CLASS TYPE PARAMETER DESCRIPTION
ModelInputT

Type of model input data.

shard

shard(inputs)

Shard inputs across devices.

PARAMETER DESCRIPTION
inputs

Input data to shard.

TYPE: ModelInputT

RETURNS DESCRIPTION
ModelInputT

Sharded input data with an additional device dimension.

Source code in src/formed/integrations/flax/distributors.py
def shard(self, inputs: ModelInputT) -> ModelInputT:
    """Shard inputs across devices.

    Args:
        inputs: Input data to shard.

    Returns:
        Sharded input data with an additional device dimension.

    """
    return inputs

replicate

replicate(inputs)

Replicate data across all devices.

PARAMETER DESCRIPTION
inputs

Data to replicate.

TYPE: _T

RETURNS DESCRIPTION
_T

Replicated data with device dimension.

Source code in src/formed/integrations/flax/distributors.py
def replicate(self, inputs: _T) -> _T:
    """Replicate data across all devices.

    Args:
        inputs: Data to replicate.

    Returns:
        Replicated data with device dimension.

    """
    return inputs

unreplicate

unreplicate(inputs)

Extract data from the first device, removing device dimension.

PARAMETER DESCRIPTION
inputs

Replicated data with device dimension.

TYPE: _T

RETURNS DESCRIPTION
_T

Data from first device without device dimension.

Source code in src/formed/integrations/flax/distributors.py
def unreplicate(self, inputs: _T) -> _T:
    """Extract data from the first device, removing device dimension.

    Args:
        inputs: Replicated data with device dimension.

    Returns:
        Data from first device without device dimension.

    """
    return inputs

map abstractmethod

map(fn, static_argnums=())

Map a function across devices with JIT compilation.

PARAMETER DESCRIPTION
fn

Function to map across devices.

TYPE: _CallableT

static_argnums

Indices of static arguments.

TYPE: Sequence[int] DEFAULT: ()

RETURNS DESCRIPTION
_CallableT

Mapped and compiled function.

Source code in src/formed/integrations/flax/distributors.py
@abc.abstractmethod
def map(self, fn: _CallableT, static_argnums: Sequence[int] = ()) -> _CallableT:
    """Map a function across devices with JIT compilation.

    Args:
        fn: Function to map across devices.
        static_argnums: Indices of static arguments.

    Returns:
        Mapped and compiled function.

    """
    raise NotImplementedError

reduce abstractmethod

reduce(array, op='mean')

Reduce an array across devices.

PARAMETER DESCRIPTION
array

Array to reduce.

TYPE: _ArrayT

op

Reduction operation ("mean" or "sum").

TYPE: _ReduceOp DEFAULT: 'mean'

RETURNS DESCRIPTION
_ArrayT

Reduced array.

Source code in src/formed/integrations/flax/distributors.py
@abc.abstractmethod
def reduce(self, array: _ArrayT, op: _ReduceOp = "mean") -> _ArrayT:
    """Reduce an array across devices.

    Args:
        array: Array to reduce.
        op: Reduction operation (`"mean"` or `"sum"`).

    Returns:
        Reduced array.

    """
    raise NotImplementedError

SingleDeviceDistributor

Bases: BaseDistributor[ModelInputT]

Distributor for single-device training.

This distributor applies JIT compilation without any device distribution. All shard, replicate, and unreplicate operations are no-ops.

Examples:

>>> distributor = SingleDeviceDistributor()
>>> train_step = distributor.map(my_train_function)
>>> output = train_step(batch, state, trainer)

map

map(fn, static_argnums=())

Apply JIT compilation to a function.

PARAMETER DESCRIPTION
fn

Function to compile.

TYPE: _CallableT

static_argnums

Indices of static arguments.

TYPE: Sequence[int] DEFAULT: ()

RETURNS DESCRIPTION
_CallableT

JIT-compiled function.

Source code in src/formed/integrations/flax/distributors.py
def map(self, fn: _CallableT, static_argnums: Sequence[int] = ()) -> _CallableT:
    """Apply JIT compilation to a function.

    Args:
        fn: Function to compile.
        static_argnums: Indices of static arguments.

    Returns:
        JIT-compiled function.

    """
    return cast(_CallableT, nnx.jit(fn, static_argnums=static_argnums))

reduce

reduce(array, op='mean')

Return array unchanged (no reduction needed for single device).

PARAMETER DESCRIPTION
array

Input array.

TYPE: _ArrayT

op

Reduction operation (ignored).

TYPE: _ReduceOp DEFAULT: 'mean'

RETURNS DESCRIPTION
_ArrayT

Input array unchanged.

Source code in src/formed/integrations/flax/distributors.py
def reduce(self, array: _ArrayT, op: _ReduceOp = "mean") -> _ArrayT:
    """Return array unchanged (no reduction needed for single device).

    Args:
        array: Input array.
        op: Reduction operation (ignored).

    Returns:
        Input array unchanged.

    """
    return array

shard

shard(inputs)

Shard inputs across devices.

PARAMETER DESCRIPTION
inputs

Input data to shard.

TYPE: ModelInputT

RETURNS DESCRIPTION
ModelInputT

Sharded input data with an additional device dimension.

Source code in src/formed/integrations/flax/distributors.py
def shard(self, inputs: ModelInputT) -> ModelInputT:
    """Shard inputs across devices.

    Args:
        inputs: Input data to shard.

    Returns:
        Sharded input data with an additional device dimension.

    """
    return inputs

replicate

replicate(inputs)

Replicate data across all devices.

PARAMETER DESCRIPTION
inputs

Data to replicate.

TYPE: _T

RETURNS DESCRIPTION
_T

Replicated data with device dimension.

Source code in src/formed/integrations/flax/distributors.py
def replicate(self, inputs: _T) -> _T:
    """Replicate data across all devices.

    Args:
        inputs: Data to replicate.

    Returns:
        Replicated data with device dimension.

    """
    return inputs

unreplicate

unreplicate(inputs)

Extract data from the first device, removing device dimension.

PARAMETER DESCRIPTION
inputs

Replicated data with device dimension.

TYPE: _T

RETURNS DESCRIPTION
_T

Data from first device without device dimension.

Source code in src/formed/integrations/flax/distributors.py
def unreplicate(self, inputs: _T) -> _T:
    """Extract data from the first device, removing device dimension.

    Args:
        inputs: Replicated data with device dimension.

    Returns:
        Data from first device without device dimension.

    """
    return inputs

DataParallelDistributor

DataParallelDistributor(
    axis_name="batch", num_devices=None
)

Bases: BaseDistributor[ModelInputT]

Distributor for data-parallel training across multiple devices.

This distributor uses JAX's pmap to execute the same computation on different data shards across multiple devices. Data is automatically sharded along the batch dimension.

PARAMETER DESCRIPTION
axis_name

Name for the device axis (used in reduction operations).

TYPE: str DEFAULT: 'batch'

num_devices

Number of devices to use. Defaults to all local devices.

TYPE: int | None DEFAULT: None

Examples:

>>> # Train on 4 GPUs with data parallelism
>>> distributor = DataParallelDistributor(
...     axis_name="batch",
...     num_devices=4
... )
>>>
>>> # Shard batch of size 32 into 4 shards of size 8
>>> sharded = distributor.shard(batch)
>>> assert sharded.shape[:2] == (4, 8)
>>>
>>> # Map training step across devices
>>> train_step = distributor.map(my_train_step)
Note

Batch size must be divisible by num_devices for proper sharding.

Source code in src/formed/integrations/flax/distributors.py
def __init__(
    self,
    axis_name: str = "batch",
    num_devices: int | None = None,
) -> None:
    self._axis_name = axis_name
    self._num_devices = num_devices or jax.local_device_count()

shard

shard(inputs)

Shard inputs along the batch dimension across devices.

PARAMETER DESCRIPTION
inputs

Input data with batch dimension.

TYPE: ModelInputT

RETURNS DESCRIPTION
ModelInputT

Sharded inputs with shape (num_devices, batch_per_device, ...).

Source code in src/formed/integrations/flax/distributors.py
def shard(self, inputs: ModelInputT) -> ModelInputT:
    """Shard inputs along the batch dimension across devices.

    Args:
        inputs: Input data with batch dimension.

    Returns:
        Sharded inputs with shape (num_devices, batch_per_device, ...).

    """
    return jax.tree_util.tree_map(lambda x: x.reshape((self._num_devices, -1) + x.shape[1:]), inputs)
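The reshape above can be pictured with a plain-Python stand-in (`shard_batch` below is illustrative, not part of the library): a flat batch of size B becomes `num_devices` shards of size `B // num_devices`, which is why the batch size must divide evenly by the device count.

```python
def shard_batch(batch, num_devices):
    """Split a flat batch (a list of examples) into per-device shards,
    mimicking x.reshape((num_devices, -1) + x.shape[1:])."""
    if len(batch) % num_devices != 0:
        raise ValueError(f"Batch size {len(batch)} is not divisible by {num_devices}")
    per_device = len(batch) // num_devices
    return [batch[i * per_device:(i + 1) * per_device] for i in range(num_devices)]

# A batch of 8 examples on 4 devices yields 4 shards of 2 examples each.
print(shard_batch(list(range(8)), num_devices=4))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```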

replicate

replicate(inputs)

Replicate data across all devices.

PARAMETER DESCRIPTION
inputs

Data to replicate.

TYPE: _T

RETURNS DESCRIPTION
_T

Replicated data with device dimension.

Source code in src/formed/integrations/flax/distributors.py
def replicate(self, inputs: _T) -> _T:
    """Replicate data across all devices.

    Args:
        inputs: Data to replicate.

    Returns:
        Replicated data with device dimension.

    """
    return flax.jax_utils.replicate(inputs)

unreplicate

unreplicate(inputs)

Extract data from the first device.

PARAMETER DESCRIPTION
inputs

Replicated data.

TYPE: _T

RETURNS DESCRIPTION
_T

Data from first device.

Source code in src/formed/integrations/flax/distributors.py
def unreplicate(self, inputs: _T) -> _T:
    """Extract data from the first device.

    Args:
        inputs: Replicated data.

    Returns:
        Data from first device.

    """
    return flax.jax_utils.unreplicate(inputs)
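The replicate/unreplicate pair forms a round trip: replication stacks copies along a new leading device axis, and unreplication keeps the first copy. A minimal stand-in sketch (these helper functions are illustrative, not the `flax.jax_utils` implementation):

```python
def replicate(x, num_devices):
    # Stand-in for flax.jax_utils.replicate: copy x along a new leading device axis.
    return [x for _ in range(num_devices)]

def unreplicate(replicated):
    # Stand-in for flax.jax_utils.unreplicate: keep only the first device's copy.
    return replicated[0]

state = {"step": 7, "loss": 0.25}
copies = replicate(state, num_devices=4)
assert len(copies) == 4
assert unreplicate(copies) == state  # the round trip recovers the original
```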

map

map(fn, static_argnums=())

Map function across devices using pmap.

PARAMETER DESCRIPTION
fn

Function to parallelize.

TYPE: _CallableT

static_argnums

Indices of static arguments to broadcast.

TYPE: Sequence[int] DEFAULT: ()

RETURNS DESCRIPTION
_CallableT

Parallelized function using pmap.

Source code in src/formed/integrations/flax/distributors.py
def map(self, fn: _CallableT, static_argnums: Sequence[int] = ()) -> _CallableT:
    """Map function across devices using pmap.

    Args:
        fn: Function to parallelize.
        static_argnums: Indices of static arguments to broadcast.

    Returns:
        Parallelized function using pmap.

    """
    return nnx.pmap(fn, axis_name=self._axis_name, static_broadcasted_argnums=static_argnums)

reduce

reduce(array, op='mean')

Reduce array across devices.

PARAMETER DESCRIPTION
array

Array to reduce across device dimension.

TYPE: _ArrayT

op

Reduction operation ("sum" or "mean").

TYPE: _ReduceOp DEFAULT: 'mean'

RETURNS DESCRIPTION
_ArrayT

Reduced array.

RAISES DESCRIPTION
ValueError

If unsupported reduction operation is specified.

Source code in src/formed/integrations/flax/distributors.py
def reduce(self, array: _ArrayT, op: _ReduceOp = "mean") -> _ArrayT:
    """Reduce array across devices.

    Args:
        array: Array to reduce across device dimension.
        op: Reduction operation - `"sum"` or `"mean"`.

    Returns:
        Reduced array.

    Raises:
        ValueError: If unsupported reduction operation is specified.

    """
    if op == "sum":
        return jax.lax.psum(array, axis_name=self._axis_name)
    elif op == "mean":
        return jax.lax.pmean(array, axis_name=self._axis_name)
    raise ValueError(f"Unsupported reduce operation: {op}")
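The dispatch above reduces one value per device across the named axis: `psum` totals them, `pmean` averages them. A pure-Python sketch of the same dispatch, with `per_device_values` standing in for one scalar per device (`reduce_across_devices` is illustrative, not the library API):

```python
def reduce_across_devices(per_device_values, op="mean"):
    """Sketch of the reduce dispatch: 'sum' mirrors psum, 'mean' mirrors pmean."""
    if op == "sum":
        return sum(per_device_values)
    elif op == "mean":
        return sum(per_device_values) / len(per_device_values)
    raise ValueError(f"Unsupported reduce operation: {op}")

losses = [1.0, 2.0, 3.0, 2.0]  # hypothetical per-device losses
print(reduce_across_devices(losses, op="sum"))   # 8.0
print(reduce_across_devices(losses, op="mean"))  # 2.0
```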

formed.integrations.flax.model

Base model abstraction for Flax NNX models.

This module provides the base class for all Flax models in the framework, integrating Flax NNX with the registrable pattern for configuration-based model instantiation.

Key Features
  • Integration with Flax NNX Module system
  • Registrable pattern for configuration-based instantiation
  • Generic type support for inputs, outputs, and parameters
  • Compatible with FlaxTrainer for end-to-end training

Examples:

>>> from formed.integrations.flax import BaseFlaxModel
>>> from flax import nnx
>>> import jax
>>>
>>> @BaseFlaxModel.register("my_model")
... class MyModel(BaseFlaxModel[dict, jax.Array, None]):
...     def __init__(self, rngs: nnx.Rngs, hidden_dim: int):
...         self.linear = nnx.Linear(10, hidden_dim, rngs=rngs)
...
...     def __call__(self, inputs: dict, params: None = None) -> jax.Array:
...         return self.linear(inputs["features"])

BaseFlaxModel

Bases: Module, Registrable, Generic[ModelInputT, ModelOutputT, ModelParamsT]

Base class for all Flax NNX models in the framework.

This class combines Flax's NNX Module with the registrable pattern, allowing models to be instantiated from configuration files and seamlessly integrated with the training infrastructure.

CLASS TYPE PARAMETER DESCRIPTION
ModelInputT

Type of input data to the model.

ModelOutputT

Type of model output.

ModelParamsT

Type of additional parameters (typically None or a dataclass).

Note

Subclasses should implement __call__ to define the forward pass. Models are automatically compatible with FlaxTrainer when registered.

formed.integrations.flax.modules.embedders

Text embedding modules for Flax models.

This module provides embedders that convert tokenized text into dense vector representations. Embedders handle various text representations including surface forms, part-of-speech tags, and character sequences.

Key Components
  • BaseEmbedder: Abstract base class for all embedders
  • TokenEmbedder: Embeds token ID sequences into dense vectors
  • AnalyzedTextEmbedder: Combines multiple embedding types (surface, POS, chars)
Features
  • Support for nested token sequences (e.g., word -> character)
  • Automatic masking and padding handling
  • Configurable vectorization for character-level embeddings
  • Concatenation of multiple embedding types

Examples:

>>> from formed.integrations.flax.modules import TokenEmbedder, AnalyzedTextEmbedder
>>> from flax import nnx
>>>
>>> # Simple token embedder
>>> embedder = TokenEmbedder(
...     vocab_size=10000,
...     embedding_dim=128,
...     rngs=nnx.Rngs(0)
... )
>>>
>>> # Multi-feature embedder
>>> embedder = AnalyzedTextEmbedder(
...     surface=TokenEmbedder(vocab_size=10000, embedding_dim=128, rngs=rngs),
...     postag=TokenEmbedder(vocab_size=50, embedding_dim=32, rngs=rngs)
... )

EmbedderOutput

Bases: NamedTuple

Output from an embedder.

ATTRIBUTE DESCRIPTION
embeddings

Dense embeddings of shape (batch_size, seq_len, embedding_dim).

TYPE: Array

mask

Attention mask of shape (batch_size, seq_len).

TYPE: Array

embeddings instance-attribute

embeddings

mask instance-attribute

mask
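The relationship between the two fields can be sketched with plain lists standing in for arrays: the mask is 1 wherever a real token sits and 0 over padding. The pad ID of 0 here is an assumption for illustration, not a guarantee of the library.

```python
# Hedged sketch: mask[i][j] = 1 for real tokens, 0 for padding.
PAD_ID = 0  # assumed pad ID for illustration

token_ids = [
    [5, 9, 2, PAD_ID],            # sequence of length 3, padded to 4
    [3, PAD_ID, PAD_ID, PAD_ID],  # sequence of length 1, padded to 4
]

mask = [[1 if t != PAD_ID else 0 for t in row] for row in token_ids]
print(mask)  # [[1, 1, 1, 0], [1, 0, 0, 0]]
```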

BaseEmbedder

Bases: Module, Registrable, Generic[_TextBatchT], ABC

Abstract base class for text embedders.

Embedders convert tokenized text into dense vector representations. They output both embeddings and attention masks.

CLASS TYPE PARAMETER DESCRIPTION
_TextBatchT

Type of input batch (e.g., IIDSequenceBatch, IAnalyzedTextBatch).

get_output_dim abstractmethod

get_output_dim()

Get the output embedding dimension.

RETURNS DESCRIPTION
int

Embedding dimension.

Source code in src/formed/integrations/flax/modules/embedders.py
@abc.abstractmethod
def get_output_dim(self) -> int:
    """Get the output embedding dimension.

    Returns:
        Embedding dimension.

    """
    raise NotImplementedError

TokenEmbedder

TokenEmbedder(
    vocab_size, embedding_dim, *, vectorizer=None, rngs=None
)

Bases: BaseEmbedder[IIDSequenceBatch]

Embedder for token ID sequences.

This embedder converts token IDs into dense embeddings using a learned embedding matrix. It supports both 2D (batch_size, seq_len) and 3D (batch_size, seq_len, char_len) token ID tensors.

For 3D inputs (e.g., character-level tokens within words), the embedder can either average the embeddings or apply a custom vectorizer.

PARAMETER DESCRIPTION
vocab_size

Size of the vocabulary.

TYPE: int

embedding_dim

Dimension of the embedding vectors.

TYPE: int

vectorizer

Optional vectorizer for 3D inputs (character sequences).

TYPE: BaseSequenceVectorizer | None DEFAULT: None

rngs

Random number generators.

TYPE: Rngs | None DEFAULT: None

Examples:

>>> # Simple word embeddings
>>> embedder = TokenEmbedder(vocab_size=10000, embedding_dim=128, rngs=rngs)
>>> output = embedder(word_ids_batch)
>>>
>>> # Character-level embeddings with pooling
>>> from formed.integrations.flax.modules import BagOfEmbeddingsSequenceVectorizer
>>> embedder = TokenEmbedder(
...     vocab_size=256,
...     embedding_dim=32,
...     vectorizer=BagOfEmbeddingsSequenceVectorizer(pooling="max"),
...     rngs=rngs
... )
Source code in src/formed/integrations/flax/modules/embedders.py
def __init__(
    self,
    vocab_size: int,
    embedding_dim: int,
    *,
    vectorizer: BaseSequenceVectorizer | None = None,
    rngs: nnx.Rngs | None = None,
) -> None:
    rngs = rngs or require_rngs()
    self._embedding = nnx.Embed(num_embeddings=vocab_size, features=embedding_dim, rngs=rngs)
    self._vectorizer = vectorizer

get_output_dim

get_output_dim()
Source code in src/formed/integrations/flax/modules/embedders.py
def get_output_dim(self) -> int:
    return self._embedding.features

AnalyzedTextEmbedder

AnalyzedTextEmbedder(
    surface=None, postag=None, character=None
)

Bases: BaseEmbedder[IAnalyzedTextBatch]

Embedder for analyzed text with multiple linguistic features.

This embedder combines embeddings from multiple linguistic representations (surface forms, part-of-speech tags, character sequences) by concatenating them along the feature dimension.

PARAMETER DESCRIPTION
surface

Optional embedder for surface form tokens.

TYPE: BaseEmbedder[IIDSequenceBatch] | None DEFAULT: None

postag

Optional embedder for part-of-speech tags.

TYPE: BaseEmbedder[IIDSequenceBatch] | None DEFAULT: None

character

Optional embedder for character sequences.

TYPE: BaseEmbedder[IIDSequenceBatch] | None DEFAULT: None

RAISES DESCRIPTION
ValueError

If all embedders are None (at least one is required).

Examples:

>>> from formed.integrations.flax.modules import (
...     AnalyzedTextEmbedder,
...     TokenEmbedder
... )
>>>
>>> embedder = AnalyzedTextEmbedder(
...     surface=TokenEmbedder(vocab_size=10000, embedding_dim=128, rngs=rngs),
...     postag=TokenEmbedder(vocab_size=50, embedding_dim=32, rngs=rngs),
...     character=TokenEmbedder(vocab_size=256, embedding_dim=32, rngs=rngs)
... )
>>>
>>> # Output dimension is sum of all embedding dimensions (128 + 32 + 32 = 192)
>>> assert embedder.get_output_dim() == 192
Note

All provided embedders share the same mask, which is taken from the last non-None embedder processed.

Source code in src/formed/integrations/flax/modules/embedders.py
def __init__(
    self,
    surface: BaseEmbedder[IIDSequenceBatch] | None = None,
    postag: BaseEmbedder[IIDSequenceBatch] | None = None,
    character: BaseEmbedder[IIDSequenceBatch] | None = None,
) -> None:
    if all(embedder is None for embedder in (surface, postag, character)):
        raise ValueError("At least one embedder must be provided for AnalyzedTextEmbedder.")

    self._surface = surface
    self._postag = postag
    self._character = character

get_output_dim

get_output_dim()
Source code in src/formed/integrations/flax/modules/embedders.py
def get_output_dim(self) -> int:
    return sum(
        embedder.get_output_dim()
        for embedder in (self._surface, self._postag, self._character)
        if embedder is not None
    )

formed.integrations.flax.modules.encoders

Sequence encoding modules for Flax models.

This module provides encoders that process sequential data, including position encoders, RNN-based encoders, and transformer encoders.

Key Components

Position Encoders:
  • BasePositionEncoder: Abstract base for position encoding
  • SinusoidalPositionEncoder: Sinusoidal position embeddings
  • LearnablePositionEncoder: Learned position embeddings

Sequence Encoders:
  • BaseSequenceEncoder: Abstract base for sequence encoders
  • RNNSequenceEncoder: Generic RNN encoder (LSTM, GRU, vanilla RNN)
  • LSTMSequenceEncoder: LSTM-specific encoder
  • OptimizedLSTMSequenceEncoder: Optimized LSTM encoder
  • GRUSequenceEncoder: GRU-specific encoder
  • TransformerSequenceEncoder: Transformer encoder with multi-head attention

Features
  • Bidirectional RNN support
  • Stacked layers with dropout
  • Position encoding for transformers
  • Efficient implementation using scan and vmap
  • Masked sequence processing

Examples:

>>> from formed.integrations.flax.modules import (
...     LSTMSequenceEncoder,
...     TransformerSequenceEncoder,
...     SinusoidalPositionEncoder
... )
>>> from flax import nnx
>>>
>>> # Bidirectional LSTM encoder
>>> encoder = LSTMSequenceEncoder(
...     features=128,
...     num_layers=2,
...     bidirectional=True,
...     dropout=0.1,
...     rngs=nnx.Rngs(0)
... )
>>>
>>> # Transformer encoder with sinusoidal positions
>>> encoder = TransformerSequenceEncoder(
...     features=128,
...     num_heads=8,
...     num_layers=6,
...     position_encoder=SinusoidalPositionEncoder(),
...     rngs=rngs
... )

BasePositionEncoder

Bases: Registrable, ABC

Abstract base class for position encoders.

Position encoders add positional information to input embeddings, allowing models to understand token positions in sequences.

SinusoidalPositionEncoder

SinusoidalPositionEncoder(max_length=512)

Bases: BasePositionEncoder

Sinusoidal position encoding from "Attention Is All You Need".

This encoder uses sine and cosine functions of different frequencies to generate position embeddings without learnable parameters.

PARAMETER DESCRIPTION
max_length

Maximum sequence length to support.

TYPE: int DEFAULT: 512

Examples:

>>> encoder = SinusoidalPositionEncoder(max_length=512)
>>> encoded = encoder(embeddings)
Note

Encodings are cached for efficiency.
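The encoding follows the formulas PE(pos, 2i) = sin(pos / base^(2i/d)) and PE(pos, 2i+1) = cos(pos / base^(2i/d)) from the paper. A minimal stdlib sketch of one encoding row (`sinusoidal_encoding` is illustrative, not the module's API):

```python
import math

def sinusoidal_encoding(position, dim, base=10000.0):
    """One row of sinusoidal position encoding: even indices use sin,
    odd indices use cos, at frequencies decreasing geometrically with i."""
    row = []
    for i in range(dim):
        angle = position / base ** (2 * (i // 2) / dim)
        row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
    return row

# Position 0 encodes to alternating sin(0) = 0.0 and cos(0) = 1.0.
print(sinusoidal_encoding(0, 4))  # [0.0, 1.0, 0.0, 1.0]
```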

Source code in src/formed/integrations/flax/modules/encoders.py
def __init__(self, max_length: int = 512) -> None:
    self.max_length = max_length

max_length instance-attribute

max_length = max_length

LearnablePositionEncoder

LearnablePositionEncoder(
    features, *, max_length=512, rngs=None
)

Bases: BasePositionEncoder

Learnable position embeddings.

This encoder uses a learned embedding matrix for position encoding, allowing the model to learn task-specific positional patterns.

PARAMETER DESCRIPTION
features

Embedding dimension.

TYPE: int

rngs

Random number generators.

TYPE: Rngs | None DEFAULT: None

max_length

Maximum sequence length to support.

TYPE: int DEFAULT: 512

Examples:

>>> encoder = LearnablePositionEncoder(
...     features=128,
...     max_length=512,
...     rngs=rngs
... )
Source code in src/formed/integrations/flax/modules/encoders.py
def __init__(
    self,
    features: int,
    *,
    max_length: int = 512,
    rngs: nnx.Rngs | None = None,
) -> None:
    rngs = rngs or require_rngs()
    self.embed = nnx.Embed(max_length, features, rngs=rngs)

embed instance-attribute

embed = Embed(max_length, features, rngs=rngs)

BaseSequenceEncoder

Bases: Module, Registrable, ABC

Abstract base class for sequence encoders.

Sequence encoders process sequential data and output contextual representations for each position in the sequence.

get_input_dim abstractmethod

get_input_dim()

Get the expected input dimension.

RETURNS DESCRIPTION
int | None

Input feature dimension or None if dimension-agnostic.

Source code in src/formed/integrations/flax/modules/encoders.py
@abc.abstractmethod
def get_input_dim(self) -> int | None:
    """Get the expected input dimension.

    Returns:
        Input feature dimension or None if dimension-agnostic.

    """
    raise NotImplementedError

get_output_dim abstractmethod

get_output_dim()

Get the output dimension.

RETURNS DESCRIPTION
int | Callable[[int], int]

Output feature dimension or a function mapping input dim to output dim.

Source code in src/formed/integrations/flax/modules/encoders.py
@abc.abstractmethod
def get_output_dim(self) -> int | Callable[[int], int]:
    """Get the output dimension.

    Returns:
        Output feature dimension or a function mapping input dim to output dim.

    """
    raise NotImplementedError

RNNSequenceEncoder

RNNSequenceEncoder(
    cell_factory,
    features,
    num_layers=1,
    bidirectional=False,
    feedforward_layers=None,
    dropout=0.0,
    rngs=None,
)

Bases: BaseSequenceEncoder

Generic RNN-based sequence encoder.

This encoder supports various RNN cell types (LSTM, GRU, vanilla RNN) with optional bidirectionality, multiple layers, and dropout.

PARAMETER DESCRIPTION
cell_factory

Function that creates an RNN cell given rngs.

TYPE: Callable[[Rngs], RNNCellBase]

features

Hidden dimension of the RNN.

TYPE: int

num_layers

Number of stacked RNN layers.

TYPE: int DEFAULT: 1

bidirectional

Whether to use bidirectional RNN.

TYPE: bool DEFAULT: False

feedforward_layers

Optional number of feedforward layers applied after each RNN layer.

TYPE: int | None DEFAULT: None

dropout

Dropout rate between layers.

TYPE: float DEFAULT: 0.0

rngs

Random number generators (int or nnx.Rngs).

TYPE: Rngs | None DEFAULT: None

Note

This is a base class. Use LSTMSequenceEncoder, GRUSequenceEncoder, or OptimizedLSTMSequenceEncoder for specific cell types.

Source code in src/formed/integrations/flax/modules/encoders.py
def __init__(
    self,
    cell_factory: Callable[[nnx.Rngs], nnx.RNNCellBase],
    features: int,
    num_layers: int = 1,
    bidirectional: bool = False,
    feedforward_layers: int | None = None,
    dropout: float = 0.0,
    rngs: nnx.Rngs | None = None,
) -> None:
    rngs = rngs or require_rngs()

    @nnx.vmap(in_axes=0, out_axes=0)
    def create_block(rngs: nnx.Rngs) -> RNNSequenceEncoder._RNNBlock:
        rnn: nnx.RNN | nnx.Bidirectional
        if bidirectional:
            forward_cell = cell_factory(rngs)
            backward_cell = cell_factory(rngs)
            rnn = nnx.Bidirectional(
                forward_rnn=nnx.RNN(forward_cell, rngs=rngs),
                backward_rnn=nnx.RNN(backward_cell, rngs=rngs),
                rngs=rngs,
            )
        else:
            rnn = nnx.RNN(cell_factory(rngs), rngs=rngs)
        return self._RNNBlock(
            rnn,
            feedforward_layers=feedforward_layers or 0,
            dropout=dropout,
            rngs=rngs,
        )

    self.blocks = create_block(rngs.fork(split=num_layers))
    self.num_layers = num_layers
    self.bidirectional = bidirectional
    self.features = features

blocks instance-attribute

blocks = create_block(fork(split=num_layers))

num_layers instance-attribute

num_layers = num_layers

bidirectional instance-attribute

bidirectional = bidirectional

features instance-attribute

features = features

get_input_dim

get_input_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
def get_input_dim(self) -> int:
    return self.features

get_output_dim

get_output_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
def get_output_dim(self) -> int:
    return self.features

LSTMSequenceEncoder

LSTMSequenceEncoder(
    features,
    num_layers=1,
    bidirectional=False,
    feedforward_layers=None,
    dropout=0.0,
    rngs=None,
)

Bases: RNNSequenceEncoder

LSTM-based sequence encoder.

PARAMETER DESCRIPTION
features

Hidden dimension.

TYPE: int

num_layers

Number of LSTM layers.

TYPE: int DEFAULT: 1

bidirectional

Whether to use bidirectional LSTM.

TYPE: bool DEFAULT: False

dropout

Dropout rate between layers.

TYPE: float DEFAULT: 0.0

rngs

Random number generators.

TYPE: Rngs | None DEFAULT: None

Examples:

>>> # Bidirectional 2-layer LSTM
>>> encoder = LSTMSequenceEncoder(
...     features=128,
...     num_layers=2,
...     bidirectional=True,
...     dropout=0.1,
...     rngs=0
... )
Source code in src/formed/integrations/flax/modules/encoders.py
def __init__(
    self,
    features: int,
    num_layers: int = 1,
    bidirectional: bool = False,
    feedforward_layers: int | None = None,
    dropout: float = 0.0,
    rngs: nnx.Rngs | None = None,
) -> None:
    rngs = rngs or require_rngs()
    super().__init__(
        cell_factory=lambda rngs: nnx.LSTMCell(features, features, rngs=rngs),
        features=features,
        num_layers=num_layers,
        bidirectional=bidirectional,
        feedforward_layers=feedforward_layers,
        dropout=dropout,
        rngs=rngs,
    )

blocks instance-attribute

blocks = create_block(fork(split=num_layers))

num_layers instance-attribute

num_layers = num_layers

bidirectional instance-attribute

bidirectional = bidirectional

features instance-attribute

features = features

get_input_dim

get_input_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
def get_input_dim(self) -> int:
    return self.features

get_output_dim

get_output_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
def get_output_dim(self) -> int:
    return self.features

OptimizedLSTMSequenceEncoder

OptimizedLSTMSequenceEncoder(
    features,
    num_layers=1,
    bidirectional=False,
    feedforward_layers=None,
    dropout=0.0,
    rngs=None,
)

Bases: RNNSequenceEncoder

Optimized LSTM sequence encoder.

Uses Flax's optimized LSTM implementation for better performance.

PARAMETER DESCRIPTION
features

Hidden dimension.

TYPE: int

num_layers

Number of LSTM layers.

TYPE: int DEFAULT: 1

bidirectional

Whether to use bidirectional LSTM.

TYPE: bool DEFAULT: False

dropout

Dropout rate between layers.

TYPE: float DEFAULT: 0.0

rngs

Random number generators.

TYPE: Rngs | None DEFAULT: None

Source code in src/formed/integrations/flax/modules/encoders.py
def __init__(
    self,
    features: int,
    num_layers: int = 1,
    bidirectional: bool = False,
    feedforward_layers: int | None = None,
    dropout: float = 0.0,
    rngs: nnx.Rngs | None = None,
) -> None:
    super().__init__(
        cell_factory=lambda rngs: nnx.OptimizedLSTMCell(features, features, rngs=rngs),
        features=features,
        num_layers=num_layers,
        bidirectional=bidirectional,
        feedforward_layers=feedforward_layers,
        dropout=dropout,
        rngs=rngs,
    )

blocks instance-attribute

blocks = create_block(fork(split=num_layers))

num_layers instance-attribute

num_layers = num_layers

bidirectional instance-attribute

bidirectional = bidirectional

features instance-attribute

features = features

get_input_dim

get_input_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
def get_input_dim(self) -> int:
    return self.features

get_output_dim

get_output_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
def get_output_dim(self) -> int:
    return self.features

GRUSequenceEncoder

GRUSequenceEncoder(
    features,
    num_layers=1,
    bidirectional=False,
    feedforward_layers=None,
    dropout=0.0,
    rngs=None,
)

Bases: RNNSequenceEncoder

GRU-based sequence encoder.

PARAMETER DESCRIPTION
features

Hidden dimension.

TYPE: int

num_layers

Number of GRU layers.

TYPE: int DEFAULT: 1

bidirectional

Whether to use bidirectional GRU.

TYPE: bool DEFAULT: False

dropout

Dropout rate between layers.

TYPE: float DEFAULT: 0.0

rngs

Random number generators.

TYPE: Rngs | None DEFAULT: None

Examples:

>>> encoder = GRUSequenceEncoder(
...     features=256,
...     num_layers=3,
...     bidirectional=True,
...     rngs=0
... )
Source code in src/formed/integrations/flax/modules/encoders.py
def __init__(
    self,
    features: int,
    num_layers: int = 1,
    bidirectional: bool = False,
    feedforward_layers: int | None = None,
    dropout: float = 0.0,
    rngs: nnx.Rngs | None = None,
) -> None:
    super().__init__(
        cell_factory=lambda rngs: nnx.GRUCell(features, features, rngs=rngs),
        features=features,
        num_layers=num_layers,
        bidirectional=bidirectional,
        feedforward_layers=feedforward_layers,
        dropout=dropout,
        rngs=rngs,
    )

blocks instance-attribute

blocks = create_block(fork(split=num_layers))

num_layers instance-attribute

num_layers = num_layers

bidirectional instance-attribute

bidirectional = bidirectional

features instance-attribute

features = features

get_input_dim

get_input_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
def get_input_dim(self) -> int:
    return self.features

get_output_dim

get_output_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
def get_output_dim(self) -> int:
    return self.features

TransformerSequenceEncoder

TransformerSequenceEncoder(
    features,
    num_heads,
    *,
    num_layers=1,
    dropout=0.0,
    epsilon=1e-06,
    feedworward_features=None,
    activation=gelu,
    position_encoder=None,
    rngs=None,
)

Bases: BaseSequenceEncoder

Transformer-based sequence encoder.

This encoder uses multi-head self-attention and feed-forward layers to process sequences, following the Transformer architecture.

PARAMETER DESCRIPTION
features

Model dimension.

TYPE: int

num_heads

Number of attention heads.

TYPE: int

num_layers

Number of transformer layers.

TYPE: int DEFAULT: 1

dropout

Dropout rate.

TYPE: float DEFAULT: 0.0

epsilon

Layer normalization epsilon.

TYPE: float DEFAULT: 1e-06

feedworward_features

Feed-forward hidden dimension (defaults to 4*features).

TYPE: int | None DEFAULT: None

activation

Activation function for feed-forward layers.

TYPE: Callable[[Array], Array] DEFAULT: gelu

position_encoder

Optional position encoder.

TYPE: BasePositionEncoder | None DEFAULT: None

rngs

Random number generators.

TYPE: Rngs | None DEFAULT: None

Examples:

>>> # Transformer with sinusoidal positions
>>> encoder = TransformerSequenceEncoder(
...     features=512,
...     num_heads=8,
...     num_layers=6,
...     dropout=0.1,
...     position_encoder=SinusoidalPositionEncoder(),
...     rngs=rngs
... )
>>>
>>> # Transformer with learnable positions
>>> encoder = TransformerSequenceEncoder(
...     features=512,
...     num_heads=8,
...     num_layers=6,
...     position_encoder=LearnablePositionEncoder(512, rngs=rngs),
...     rngs=rngs
... )
Source code in src/formed/integrations/flax/modules/encoders.py
def __init__(
    self,
    features: int,
    num_heads: int,
    *,
    num_layers: int = 1,
    dropout: float = 0.0,
    epsilon: float = 1e-6,
    feedworward_features: int | None = None,
    activation: Callable[[jax.Array], jax.Array] = jax.nn.gelu,
    position_encoder: BasePositionEncoder | None = None,
    rngs: nnx.Rngs | None = None,
) -> None:
    rngs = rngs or require_rngs()

    @nnx.vmap(in_axes=0, out_axes=0)
    def create_block(rngs: nnx.Rngs) -> TransformerSequenceEncoder._TransformerBlock:
        return self._TransformerBlock(
            features=features,
            num_heads=num_heads,
            dropout=dropout,
            epsilon=epsilon,
            feedworward_features=feedworward_features or 4 * features,
            activation=activation,
            rngs=rngs,
        )

    self.num_layers = num_layers
    self.blocks = create_block(rngs.fork(split=num_layers))
    self.position_encoder = position_encoder
    self.features = features

num_layers instance-attribute

num_layers = num_layers

blocks instance-attribute

blocks = create_block(fork(split=num_layers))

position_encoder instance-attribute

position_encoder = position_encoder

features instance-attribute

features = features

get_input_dim

get_input_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
def get_input_dim(self) -> int:
    return self.features

get_output_dim

get_output_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
def get_output_dim(self) -> int:
    return self.features

formed.integrations.flax.modules.feedforward

Feed-forward neural network modules for Flax models.

This module provides feed-forward network layers with support for multiple layers, dropout, layer normalization, and residual connections.

Key Components
  • Block: Single feed-forward block with optional normalization and dropout
  • FeedForward: Stacked feed-forward blocks with configurable connections
Features
  • Configurable activation functions
  • Layer normalization with custom epsilon
  • Dropout for regularization
  • Dense residual connections
  • Efficient implementation using scan

Examples:

>>> from formed.integrations.flax.modules import FeedForward
>>> from flax import nnx
>>> import jax.nn
>>>
>>> # Simple 3-layer feed-forward network
>>> ffn = FeedForward(
...     features=128,
...     num_layers=3,
...     dropout=0.1,
...     activation=jax.nn.gelu,
...     rngs=nnx.Rngs(0)
... )
>>>
>>> # With dense residual connections
>>> ffn = FeedForward(
...     features=128,
...     num_layers=3,
...     residual_connection="dense",
...     rngs=rngs
... )

Block

Block(
    input_dim,
    output_dim,
    dropout=0.0,
    layer_norm_eps=None,
    activation=relu,
    rngs=None,
)

Bases: Module

Single feed-forward block with optional normalization and dropout.

A block consists of a linear transformation, activation, optional dropout, and optional layer normalization. It can also accept a residual input.

PARAMETER DESCRIPTION
rngs

Random number generators.

TYPE: Rngs | None DEFAULT: None

input_dim

Input dimension.

TYPE: int

output_dim

Output dimension.

TYPE: int

dropout

Dropout rate (0 means no dropout).

TYPE: float DEFAULT: 0.0

layer_norm_eps

Layer normalization epsilon (None means no layer norm).

TYPE: float | None DEFAULT: None

activation

Activation function (default: ReLU).

TYPE: Callable[[Array], Array] DEFAULT: relu

Source code in src/formed/integrations/flax/modules/feedforward.py
def __init__(
    self,
    input_dim: int,
    output_dim: int,
    dropout: float = 0.0,
    layer_norm_eps: float | None = None,
    activation: Callable[[jax.Array], jax.Array] = jax.nn.relu,
    rngs: nnx.Rngs | None = None,
) -> None:
    rngs = rngs or require_rngs()
    self.linear = nnx.Linear(input_dim, output_dim, rngs=rngs)
    self.activation = activation
    self.dropout = nnx.Dropout(dropout, rngs=rngs) if dropout > 0.0 else None
    self.layer_norm = (
        nnx.LayerNorm(output_dim, epsilon=layer_norm_eps, rngs=rngs) if layer_norm_eps is not None else None
    )

linear instance-attribute

linear = Linear(input_dim, output_dim, rngs=rngs)

activation instance-attribute

activation = activation

dropout instance-attribute

dropout = (
    Dropout(dropout, rngs=rngs) if dropout > 0.0 else None
)

layer_norm instance-attribute

layer_norm = (
    LayerNorm(output_dim, epsilon=layer_norm_eps, rngs=rngs)
    if layer_norm_eps is not None
    else None
)

FeedForward

FeedForward(
    features,
    num_layers=1,
    dropout=0.0,
    layer_norm_eps=None,
    activation=relu,
    residual_connection="none",
    rngs=None,
)

Bases: Module

Multi-layer feed-forward neural network.

This module stacks multiple feed-forward blocks with configurable activation, dropout, normalization, and residual connections.

PARAMETER DESCRIPTION
features

Hidden dimension for all layers.

TYPE: int

num_layers

Number of layers.

TYPE: int DEFAULT: 1

dropout

Dropout rate applied after each activation.

TYPE: float DEFAULT: 0.0

layer_norm_eps

Epsilon for layer normalization (None disables layer norm).

TYPE: float | None DEFAULT: None

activation

Activation function (default: ReLU).

TYPE: Callable[[Array], Array] DEFAULT: relu

residual_connection

Type of residual connection:
  • "none": No residual connections (default)
  • "dense": Dense connections (each layer receives the sum of all previous)

TYPE: Literal['none', 'dense'] DEFAULT: 'none'

rngs

Random number generators (can be int seed or nnx.Rngs).

TYPE: Rngs | None DEFAULT: None

Examples:

>>> # Simple 3-layer network
>>> ffn = FeedForward(features=256, num_layers=3, rngs=0)
>>> output = ffn(x)
>>>
>>> # With dropout and layer norm
>>> ffn = FeedForward(
...     features=256,
...     num_layers=3,
...     dropout=0.1,
...     layer_norm_eps=1e-6,
...     rngs=0
... )
>>>
>>> # With dense residual connections
>>> ffn = FeedForward(
...     features=256,
...     num_layers=4,
...     residual_connection="dense",
...     rngs=0
... )
Note

When using residual_connection="dense", each layer receives the sum of outputs from all previous layers, similar to DenseNet.
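The dense scheme above can be sketched as follows (a minimal standalone illustration, not the library implementation; the `dense_residual_forward` helper is hypothetical):

```python
# Minimal sketch of residual_connection="dense": each block receives the
# running sum of the input and all previous block outputs, DenseNet-style.
def dense_residual_forward(blocks, x):
    accumulated = x
    for block in blocks:
        accumulated = accumulated + block(accumulated)
    return accumulated
```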

Source code in src/formed/integrations/flax/modules/feedforward.py
def __init__(
    self,
    features: int,
    num_layers: int = 1,
    dropout: float = 0.0,
    layer_norm_eps: float | None = None,
    activation: Callable[[jax.Array], jax.Array] = jax.nn.relu,
    residual_connection: Literal["none", "dense"] = "none",
    rngs: nnx.Rngs | None = None,
) -> None:
    rngs = rngs or require_rngs()

    @nnx.vmap(in_axes=0, out_axes=0)
    def create_block(rngs: nnx.Rngs) -> Block:
        return Block(features, features, dropout, layer_norm_eps, activation, rngs=rngs)

    self.features = features
    self.num_layers = num_layers
    self.blocks = create_block(rngs.fork(split=num_layers))
    self.residual_connection = residual_connection

features instance-attribute

features = features

num_layers instance-attribute

num_layers = num_layers

blocks instance-attribute

blocks = create_block(fork(split=num_layers))

residual_connection instance-attribute

residual_connection = residual_connection

get_input_dim

get_input_dim()
Source code in src/formed/integrations/flax/modules/feedforward.py
def get_input_dim(self) -> int:
    return self.features

get_output_dim

get_output_dim()
Source code in src/formed/integrations/flax/modules/feedforward.py
def get_output_dim(self) -> int:
    return self.features

formed.integrations.flax.modules.losses

BaseClassificationLoss

Bases: Module, Registrable, Generic[_ParamsT], ABC

Abstract base class for classification loss functions.

A ClassificationLoss defines a strategy for computing loss based on model logits and true labels.

CLASS TYPE PARAMETER DESCRIPTION
_ParamsT

Type of additional parameters used during loss computation.

CrossEntropyLoss

CrossEntropyLoss(weighter=None, reduce='mean')

Bases: BaseClassificationLoss[_ParamsT]

Cross-entropy loss for classification tasks.

PARAMETER DESCRIPTION
weighter

An optional label weighter to assign weights to each class.

TYPE: BaseLabelWeighter[_ParamsT] | None DEFAULT: None

reduce

Reduction method for loss computation ("mean" or "sum").

TYPE: Literal['mean', 'sum'] DEFAULT: 'mean'
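The interaction of the weighter and the reduction can be sketched with a standalone NumPy version (an illustration only, not the library's implementation; the `cross_entropy` helper is hypothetical):

```python
import numpy as np

def cross_entropy(logits, labels, class_weights=None, reduce="mean"):
    # Softmax cross-entropy with optional per-class weights and a
    # "mean"/"sum" reduction, mirroring the options described above.
    logits = logits - logits.max(axis=-1, keepdims=True)               # stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(len(labels)), labels]                   # per-example loss
    if class_weights is not None:
        nll = nll * class_weights[labels]                              # weight by true class
    return nll.mean() if reduce == "mean" else nll.sum()
```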

Source code in src/formed/integrations/flax/modules/losses.py
def __init__(
    self,
    weighter: BaseLabelWeighter[_ParamsT] | None = None,
    reduce: Literal["mean", "sum"] = "mean",
) -> None:
    super().__init__()
    self._weighter = weighter
    self._reduce = reduce

formed.integrations.flax.modules.samplers

BaseLabelSampler

Bases: Module, Registrable, Generic[_ParamsT], ABC

Abstract base class for label samplers.

A LabelSampler defines a strategy for sampling labels based on model logits.

CLASS TYPE PARAMETER DESCRIPTION
_ParamsT

Type of additional parameters used during sampling.

ArgmaxLabelSampler

Bases: BaseLabelSampler[None]

Label sampler that selects the label with the highest logit.

MultinomialLabelSamplerParams

Bases: NamedTuple

Parameters for the MultinomialLabelSampler.

This class can be extended in the future to include additional parameters for sampling if needed.

ATTRIBUTE DESCRIPTION
rngs

Random number generators.

TYPE: Rngs | None

templerature

Sampling temperature to control randomness.

TYPE: float

rngs class-attribute instance-attribute

rngs = None

templerature class-attribute instance-attribute

templerature = 1.0

MultinomialLabelSampler

Bases: BaseLabelSampler[MultinomialLabelSamplerParams]

Label sampler that samples labels from a multinomial distribution defined by the logits.
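Temperature-scaled multinomial sampling can be sketched as follows (a standalone NumPy illustration, not the library's implementation; the `sample_label` helper and its signature are hypothetical):

```python
import numpy as np

def sample_label(logits, temperature=1.0, rng=None):
    # Temperature-scaled softmax over logits, then one multinomial draw.
    # Lower temperature sharpens the distribution; higher temperature flattens it.
    rng = rng or np.random.default_rng()
    scaled = logits / temperature
    scaled = scaled - scaled.max()                  # numerical stability
    probs = np.exp(scaled) / np.exp(scaled).sum()
    return int(rng.choice(len(logits), p=probs))
```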

Params class-attribute instance-attribute

formed.integrations.flax.modules.vectorizers

Sequence vectorization modules for Flax models.

This module provides vectorizers that convert variable-length sequences into fixed-size vectors. Vectorizers apply pooling operations over the sequence dimension to produce single vectors per sequence.

Key Components
  • BaseSequenceVectorizer: Abstract base class for vectorizers
  • BagOfEmbeddingsSequenceVectorizer: Pools sequence embeddings
Features
  • Multiple pooling strategies (mean, max, min, sum, first, last, hier)
  • Masked pooling to ignore padding tokens
  • Optional normalization before pooling
  • Hierarchical pooling with sliding windows

Examples:

>>> from formed.integrations.flax.modules import BagOfEmbeddingsSequenceVectorizer
>>>
>>> # Mean pooling over sequence
>>> vectorizer = BagOfEmbeddingsSequenceVectorizer(pooling="mean")
>>> vector = vectorizer(embeddings, mask=mask)
>>>
>>> # Max pooling with normalization
>>> vectorizer = BagOfEmbeddingsSequenceVectorizer(
...     pooling="max",
...     normalize=True
... )

BaseSequenceVectorizer

Bases: Module, Registrable, ABC

Abstract base class for sequence vectorizers.

Vectorizers convert variable-length sequences into fixed-size vectors by applying pooling operations over the sequence dimension.

get_input_dim abstractmethod

get_input_dim()

Get the expected input dimension.

RETURNS DESCRIPTION
int | None

Input dimension or None if dimension-agnostic.

Source code in src/formed/integrations/flax/modules/vectorizers.py
@abc.abstractmethod
def get_input_dim(self) -> int | None:
    """Get the expected input dimension.

    Returns:
        Input dimension or None if dimension-agnostic.

    """
    raise NotImplementedError

get_output_dim abstractmethod

get_output_dim()

Get the output dimension.

RETURNS DESCRIPTION
int | Callable[[int], int]

Output feature dimension or a function mapping input dim to output dim.

Source code in src/formed/integrations/flax/modules/vectorizers.py
@abc.abstractmethod
def get_output_dim(self) -> int | Callable[[int], int]:
    """Get the output dimension.

    Returns:
        Output feature dimension or a function mapping input dim to output dim.

    """
    raise NotImplementedError

BagOfEmbeddingsSequenceVectorizer

BagOfEmbeddingsSequenceVectorizer(
    pooling="mean", normalize=False, window_size=None
)

Bases: BaseSequenceVectorizer

Bag-of-embeddings vectorizer using pooling operations.

This vectorizer applies pooling over the sequence dimension to create fixed-size vectors. Multiple pooling strategies are supported, and padding tokens are properly masked during pooling.

PARAMETER DESCRIPTION
pooling

Pooling strategy to use:
  • "mean": Average pooling (default)
  • "max": Max pooling
  • "min": Min pooling
  • "sum": Sum pooling
  • "first": Take first token
  • "last": Take last non-padding token
  • "hier": Hierarchical pooling with sliding window

TYPE: PoolingMethod | Sequence[PoolingMethod] DEFAULT: 'mean'

normalize

Whether to L2-normalize embeddings before pooling.

TYPE: bool DEFAULT: False

window_size

Window size for hierarchical pooling (required if pooling="hier").

TYPE: int | None DEFAULT: None

Examples:

>>> # Mean pooling
>>> vectorizer = BagOfEmbeddingsSequenceVectorizer(pooling="mean")
>>> vector = vectorizer(embeddings, mask=mask)
>>>
>>> # Max pooling with normalization
>>> vectorizer = BagOfEmbeddingsSequenceVectorizer(
...     pooling="max",
...     normalize=True
... )
>>>
>>> # Hierarchical pooling
>>> vectorizer = BagOfEmbeddingsSequenceVectorizer(
...     pooling="hier",
...     window_size=3
... )
Note

This vectorizer is dimension-agnostic: it preserves the embedding dimension from input to output.
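Masked pooling as described above can be sketched with a standalone NumPy version (a simplified illustration, not the library's implementation; the `masked_mean_pool` helper is hypothetical):

```python
import numpy as np

def masked_mean_pool(embeddings, mask):
    # Mean-pool over the sequence axis while ignoring padding positions.
    # embeddings: (seq_len, dim); mask: (seq_len,) boolean, True for real tokens.
    weights = mask[:, None].astype(embeddings.dtype)   # (seq_len, 1)
    total = (embeddings * weights).sum(axis=0)         # sum of unmasked tokens
    count = np.maximum(weights.sum(), 1.0)             # guard against empty mask
    return total / count
```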

Source code in src/formed/integrations/flax/modules/vectorizers.py
def __init__(
    self,
    pooling: PoolingMethod | Sequence[PoolingMethod] = "mean",
    normalize: bool = False,
    window_size: int | None = None,
) -> None:
    self._pooling: PoolingMethod | Sequence[PoolingMethod] = pooling
    self._normalize = normalize
    self._window_size = window_size

get_input_dim

get_input_dim()
Source code in src/formed/integrations/flax/modules/vectorizers.py
def get_input_dim(self) -> None:
    return None

get_output_dim

get_output_dim()
Source code in src/formed/integrations/flax/modules/vectorizers.py
def get_output_dim(self) -> Callable[[int], int]:
    num_pooling = 1 if isinstance(self._pooling, str) else len(self._pooling)
    return lambda input_dim: input_dim * num_pooling
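Since get_output_dim may return either an int or a callable, downstream code has to resolve it against the concrete input dimension; a minimal hypothetical helper:

```python
def resolve_output_dim(output_dim_spec, input_dim):
    # get_output_dim may return a fixed int or a function of the input
    # dimension; resolve it to a concrete size either way.
    return output_dim_spec(input_dim) if callable(output_dim_spec) else output_dim_spec
```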

formed.integrations.flax.modules.weighters

BaseLabelWeighter

Bases: Module, Registrable, Generic[_ParamsT], ABC

Abstract base class for label weighters.

A LabelWeighter defines a strategy for assigning weights to each label based on model logits and true targets.

CLASS TYPE PARAMETER DESCRIPTION
_ParamsT

Type of additional parameters used during weighting.

StaticLabelWeighter

StaticLabelWeighter(weights)

Bases: BaseLabelWeighter

Label weighter that assigns static weights to each class.

PARAMETER DESCRIPTION
weights

An array of shape (num_classes,) containing the weight for each class.

TYPE: ArrayCompatible

Source code in src/formed/integrations/flax/modules/weighters.py
def __init__(self, weights: ArrayCompatible) -> None:
    super().__init__()
    self._weights = nnx.Param(ensure_jax_array(weights), mutable=False)

BalancedByDistributionLabelWeighter

BalancedByDistributionLabelWeighter(
    distribution, eps=1e-08
)

Bases: BaseLabelWeighter

Label weighter that balances classes based on their distribution.

PARAMETER DESCRIPTION
distribution

An array of shape (num_classes,) representing the class distribution.

TYPE: ArrayCompatible

eps

A small epsilon value to avoid division by zero.

TYPE: float DEFAULT: 1e-08
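The exact formula is not shown here; one common inverse-frequency scheme consistent with the description above can be sketched as (the `balanced_weights` helper and its normalization are assumptions, not the library's implementation):

```python
import numpy as np

def balanced_weights(distribution, eps=1e-8):
    # Inverse-frequency class weights: rarer classes receive larger weights.
    # Normalized so the weights average to 1 (one common convention).
    inv = 1.0 / (np.asarray(distribution) + eps)
    return inv / inv.mean()
```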

Source code in src/formed/integrations/flax/modules/weighters.py
def __init__(self, distribution: ArrayCompatible, eps: float = 1e-8) -> None:
    self._distribution = nnx.Param(ensure_jax_array(distribution), mutable=False)
    self._eps = eps

formed.integrations.flax.training.callbacks

Training callbacks for monitoring and controlling Flax model training.

This module provides a callback system for Flax training, allowing custom logic to be executed at various points in the training loop. Callbacks can monitor metrics, save checkpoints, implement early stopping, and integrate with experiment tracking systems.

Key Components
  • FlaxTrainingCallback: Base class for all callbacks
  • EvaluationCallback: Computes metrics using custom evaluators
  • EarlyStoppingCallback: Stops training based on metric improvements
  • MlflowCallback: Logs metrics to MLflow
Features
  • Hook points at training/epoch/batch start and end
  • Metric computation and logging
  • Model checkpointing
  • Early stopping with patience
  • MLflow integration
  • Extensible for custom callbacks

Examples:

>>> from formed.integrations.flax import (
...     FlaxTrainer,
...     EarlyStoppingCallback,
...     EvaluationCallback,
...     MlflowCallback
... )
>>>
>>> trainer = FlaxTrainer(
...     train_dataloader=train_loader,
...     val_dataloader=val_loader,
...     callbacks=[
...         EvaluationCallback(my_evaluator),
...         EarlyStoppingCallback(patience=5, metric="-loss"),
...         MlflowCallback()
...     ]
... )

FlaxTrainingCallback

Bases: Registrable

Base class for training callbacks.

Callbacks provide hooks to execute custom logic at various points during training. Subclasses can override any hook method to implement custom behavior such as logging, checkpointing, or early stopping.

Hook execution order
  1. on_training_start - once at the beginning
  2. on_epoch_start - at the start of each epoch
  3. on_batch_start - before each training batch
  4. on_batch_end - after each training batch
  5. on_eval_start - before evaluation (returns evaluator)
  6. on_eval_end - after evaluation with computed metrics
  7. on_log - when metrics are logged
  8. on_epoch_end - at the end of each epoch
  9. on_training_end - once at the end (can modify final state)

Examples:

>>> @FlaxTrainingCallback.register("my_callback")
... class MyCallback(FlaxTrainingCallback):
...     def on_epoch_end(self, trainer, model, state, epoch):
...         print(f"Completed epoch {epoch} at step {state.step}")

on_training_start

on_training_start(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_training_start(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> None:
    pass

on_training_end

on_training_end(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_training_end(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> TrainState:
    return state

on_epoch_start

on_epoch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_epoch_start(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    pass

on_epoch_end

on_epoch_end(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_epoch_end(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    pass

on_batch_start

on_batch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_batch_start(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    pass

on_batch_end

on_batch_end(trainer, model, state, epoch, output)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_batch_end(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
    output: ModelOutputT,
) -> None:
    pass

on_eval_start

on_eval_start(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_eval_start(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> IEvaluator[ModelInputT, ModelOutputT]:
    class DummyMetric(IEvaluator):
        def update(self, inputs, output, /) -> None:
            pass

        def compute(self) -> dict[str, float]:
            return {}

        def reset(self) -> None:
            pass

        def clone(self) -> Self:
            return self

    return DummyMetric()

on_eval_end

on_eval_end(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/flax/training/callbacks.py
def on_eval_end(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    metrics: Mapping[str, float],
    prefix: str = "",
) -> None:
    pass

on_log

on_log(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/flax/training/callbacks.py
def on_log(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    metrics: Mapping[str, float],
    prefix: str = "",
) -> None:
    pass

EvaluationCallback

EvaluationCallback(evaluator)

Bases: FlaxTrainingCallback, Generic[ModelInputT, ModelOutputT]

Callback for computing metrics using a custom evaluator.

This callback integrates a custom evaluator into the training loop, resetting it before each evaluation phase and returning it for metric accumulation.

PARAMETER DESCRIPTION
evaluator

Evaluator implementing the IEvaluator protocol.

TYPE: IEvaluator[ModelInputT, ModelOutputT]

Examples:

>>> from formed.integrations.ml.metrics import MulticlassAccuracy
>>>
>>> evaluator = MulticlassAccuracy()
>>> callback = EvaluationCallback(evaluator)
Source code in src/formed/integrations/flax/training/callbacks.py
def __init__(self, evaluator: IEvaluator[ModelInputT, ModelOutputT]) -> None:
    self._evaluator = evaluator

on_eval_start

on_eval_start(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_eval_start(  # pyright: ignore[reportIncompatibleMethodOverride]
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> IEvaluator[ModelInputT, ModelOutputT]:
    self._evaluator.reset()
    return self._evaluator

on_training_start

on_training_start(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_training_start(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> None:
    pass

on_training_end

on_training_end(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_training_end(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> TrainState:
    return state

on_epoch_start

on_epoch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_epoch_start(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    pass

on_epoch_end

on_epoch_end(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_epoch_end(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    pass

on_batch_start

on_batch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_batch_start(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    pass

on_batch_end

on_batch_end(trainer, model, state, epoch, output)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_batch_end(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
    output: ModelOutputT,
) -> None:
    pass

on_eval_end

on_eval_end(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/flax/training/callbacks.py
def on_eval_end(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    metrics: Mapping[str, float],
    prefix: str = "",
) -> None:
    pass

on_log

on_log(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/flax/training/callbacks.py
def on_log(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    metrics: Mapping[str, float],
    prefix: str = "",
) -> None:
    pass

EarlyStoppingCallback

EarlyStoppingCallback(patience=5, metric='-loss')

Bases: FlaxTrainingCallback

Callback for early stopping based on metric improvements.

This callback monitors a specified metric and stops training if it doesn't improve for a given number of evaluations (patience). The best model is automatically saved and restored at the end of training.

PARAMETER DESCRIPTION
patience

Number of evaluations without improvement before stopping.

TYPE: int DEFAULT: 5

metric

Metric to monitor. Prefix with "-" to minimize (e.g., "-loss"), or "+" to maximize (e.g., "+accuracy"). Default is "-loss".

TYPE: str DEFAULT: '-loss'

Examples:

>>> # Stop if validation loss doesn't improve for 5 evaluations
>>> callback = EarlyStoppingCallback(patience=5, metric="-val/loss")
>>>
>>> # Stop if accuracy doesn't improve for 3 evaluations
>>> callback = EarlyStoppingCallback(patience=3, metric="+accuracy")
Note

The best model is saved to the step working directory and automatically restored when training ends early or completes.

Source code in src/formed/integrations/flax/training/callbacks.py
def __init__(
    self,
    patience: int = 5,
    metric: str = "-loss",
) -> None:
    self._patience = patience
    self._metric = metric.lstrip("-+")
    self._direction = -1 if metric.startswith("-") else 1
    self._best_metric = -float("inf")
    self._best_step: int | None = None
    self._counter = 0
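
The sign handling in the constructor can be sketched in isolation: negating a "-"-prefixed metric turns every comparison into "higher is better". This is an assumed simplification; the real callback additionally tracks patience and checkpoints the best state.

```python
# Sketch of the sign convention: a "-" prefix negates the raw value so the
# callback can always treat larger normalized values as improvements.
def normalize(metric: str, value: float) -> float:
    direction = -1 if metric.startswith("-") else 1
    return direction * value

# Lower loss yields a larger normalized value, i.e. an improvement.
assert normalize("-loss", 0.3) > normalize("-loss", 0.5)
# Higher accuracy is likewise an improvement.
assert normalize("+accuracy", 0.9) > normalize("+accuracy", 0.8)
```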

on_training_start

on_training_start(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_training_start(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> None:
    self._best_metric = -float("inf")
    self._counter = 0

on_eval_end

on_eval_end(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/flax/training/callbacks.py
def on_eval_end(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    metrics: Mapping[str, float],
    prefix: str = "",
) -> None:
    logger = use_step_logger(__name__)
    if prefix:
        metrics = {f"{prefix}{key}": value for key, value in metrics.items()}
    try:
        metric = self._direction * metrics[self._metric]
    except KeyError:
        return

    if metric > self._best_metric:
        self._best_metric = metric
        self._best_step = int(state.step)
        self._counter = 0
        self._checkpointer.save(int(state.step), state)
        logger.info(f"New best model saved with {self._metric}={self._direction * self._best_metric:.4f}")
    else:
        self._counter += 1
        if self._counter >= self._patience:
            raise StopEarly()

on_training_end

on_training_end(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_training_end(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> TrainState:
    logger = use_step_logger(__name__)
    if self._best_step is not None:
        logger.info("Restoring best state from early stopping checkpoint.")
        state = cast(TrainState, self._checkpointer.restore(self._best_step, items=state))
    return state

on_epoch_start

on_epoch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_epoch_start(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    pass

on_epoch_end

on_epoch_end(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_epoch_end(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    pass

on_batch_start

on_batch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_batch_start(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    pass

on_batch_end

on_batch_end(trainer, model, state, epoch, output)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_batch_end(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
    output: ModelOutputT,
) -> None:
    pass

on_eval_start

on_eval_start(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_eval_start(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> IEvaluator[ModelInputT, ModelOutputT]:
    class DummyMetric(IEvaluator):
        def update(self, inputs, output, /) -> None:
            pass

        def compute(self) -> dict[str, float]:
            return {}

        def reset(self) -> None:
            pass

        def clone(self) -> Self:
            return self

    return DummyMetric()

on_log

on_log(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/flax/training/callbacks.py
def on_log(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    metrics: Mapping[str, float],
    prefix: str = "",
) -> None:
    pass

MlflowCallback

MlflowCallback()

Bases: FlaxTrainingCallback

Callback for logging metrics to MLflow.

This callback automatically logs training and validation metrics to MLflow when used within a workflow step that has MLflow tracking enabled.

Examples:

>>> from formed.integrations.flax import FlaxTrainer, MlflowCallback
>>>
>>> trainer = FlaxTrainer(
...     train_dataloader=train_loader,
...     val_dataloader=val_loader,
...     callbacks=[MlflowCallback()]
... )
Note

Requires the formed mlflow integration and must be used within a workflow step with MLflow tracking configured.

Source code in src/formed/integrations/flax/training/callbacks.py
def __init__(self) -> None:
    from formed.integrations.mlflow.workflow import MlflowLogger

    self._mlflow_logger: MlflowLogger | None = None

on_training_start

on_training_start(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_training_start(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> None:
    from formed.integrations.mlflow.workflow import use_mlflow_logger
    from formed.workflow import use_step_logger

    logger = use_step_logger(__name__)

    self._mlflow_logger = use_mlflow_logger()
    if self._mlflow_logger is None:
        logger.warning("MlflowLogger not found. Skipping logging.")

on_log

on_log(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/flax/training/callbacks.py
def on_log(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    metrics: Mapping[str, float],
    prefix: str = "",
) -> None:
    metrics = {prefix + key: value for key, value in metrics.items()}
    if self._mlflow_logger is not None:
        for key, value in metrics.items():
            self._mlflow_logger.log_metric(key, value, step=int(state.step))

on_epoch_end

on_epoch_end(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_epoch_end(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    if self._mlflow_logger is not None:
        self._mlflow_logger.log_metric("epoch", epoch, step=int(state.step))

on_training_end

on_training_end(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_training_end(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> TrainState:
    return state

on_epoch_start

on_epoch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_epoch_start(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    pass

on_batch_start

on_batch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_batch_start(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    pass

on_batch_end

on_batch_end(trainer, model, state, epoch, output)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_batch_end(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
    output: ModelOutputT,
) -> None:
    pass

on_eval_start

on_eval_start(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
def on_eval_start(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> IEvaluator[ModelInputT, ModelOutputT]:
    class DummyMetric(IEvaluator):
        def update(self, inputs, output, /) -> None:
            pass

        def compute(self) -> dict[str, float]:
            return {}

        def reset(self) -> None:
            pass

        def clone(self) -> Self:
            return self

    return DummyMetric()

on_eval_end

on_eval_end(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/flax/training/callbacks.py
def on_eval_end(
    self,
    trainer: "FlaxTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    metrics: Mapping[str, float],
    prefix: str = "",
) -> None:
    pass

formed.integrations.flax.training.engine

Training engine abstractions for Flax models.

This module provides the training engine abstraction that defines how models are trained and evaluated. Engines handle loss computation, gradient calculation, and parameter updates.

Key Components
  • FlaxTrainingEngine: Abstract base class for training engines
  • DefaultFlaxTrainingEngine: Default implementation with automatic differentiation
Features
  • Customizable loss functions
  • Automatic gradient computation using JAX
  • State creation and management
  • Separate train and eval steps
  • Compatible with FlaxTrainer and distributors

Examples:

>>> from formed.integrations.flax import DefaultFlaxTrainingEngine
>>>
>>> # Create engine with custom loss accessor
>>> engine = DefaultFlaxTrainingEngine(loss="total_loss")
>>>
>>> # Or with custom loss function
>>> def custom_loss(output):
...     return output.loss + 0.1 * output.regularization
>>> engine = DefaultFlaxTrainingEngine(loss=custom_loss)

FlaxTrainingEngine

Bases: ABC, Registrable, Generic[ModelInputT, ModelOutputT, ModelParamsT]

Abstract base class for Flax training engines.

A training engine defines how models are trained by implementing state creation, training steps, and evaluation steps. This allows for custom training loops and loss computations.

CLASS TYPE PARAMETER DESCRIPTION
ModelInputT

Type of model input.

ModelOutputT

Type of model output.

ModelParamsT

Type of additional parameters.

create_state abstractmethod

create_state(rngs, trainer, model)

Create initial training state from model and trainer.

PARAMETER DESCRIPTION
rngs

Random number generators.

TYPE: Rngs

trainer

Trainer instance.

TYPE: FlaxTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]

model

Model to train.

TYPE: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT]

RETURNS DESCRIPTION
TrainState

Initial training state.

Source code in src/formed/integrations/flax/training/engine.py
@abc.abstractmethod
def create_state(
    self,
    rngs: nnx.Rngs,
    trainer: "FlaxTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
) -> TrainState:
    """Create initial training state from model and trainer.

    Args:
        rngs: Random number generators.
        trainer: Trainer instance.
        model: Model to train.

    Returns:
        Initial training state.

    """
    raise NotImplementedError

train_step abstractmethod

train_step(inputs, state, trainer)

Execute a single training step.

PARAMETER DESCRIPTION
inputs

Batch of training inputs.

TYPE: ModelInputT

state

Current training state.

TYPE: TrainState

trainer

Trainer instance.

TYPE: FlaxTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]

RETURNS DESCRIPTION
tuple[TrainState, ModelOutputT]

Tuple of (updated_state, model_output).

Source code in src/formed/integrations/flax/training/engine.py
@abc.abstractmethod
def train_step(
    self,
    inputs: ModelInputT,
    state: TrainState,
    trainer: "FlaxTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]",
) -> tuple[TrainState, ModelOutputT]:
    """Execute a single training step.

    Args:
        inputs: Batch of training inputs.
        state: Current training state.
        trainer: Trainer instance.

    Returns:
        Tuple of (updated_state, model_output).

    """
    raise NotImplementedError

eval_step abstractmethod

eval_step(inputs, state, trainer)

Execute a single evaluation step.

PARAMETER DESCRIPTION
inputs

Batch of evaluation inputs.

TYPE: ModelInputT

state

Current training state.

TYPE: TrainState

trainer

Trainer instance.

TYPE: FlaxTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]

RETURNS DESCRIPTION
ModelOutputT

Model output.

Source code in src/formed/integrations/flax/training/engine.py
@abc.abstractmethod
def eval_step(
    self,
    inputs: ModelInputT,
    state: TrainState,
    trainer: "FlaxTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]",
) -> ModelOutputT:
    """Execute a single evaluation step.

    Args:
        inputs: Batch of evaluation inputs.
        state: Current training state.
        trainer: Trainer instance.

    Returns:
        Model output.

    """
    raise NotImplementedError

DefaultFlaxTrainingEngine

DefaultFlaxTrainingEngine(
    loss="loss", optimizer=adamw(0.001)
)

Bases: FlaxTrainingEngine[ModelInputT, ModelOutputT, ModelParamsT]

Default training engine using automatic differentiation.

This engine computes gradients using JAX's automatic differentiation and updates parameters using the provided optimizer. Loss is extracted from model output either by attribute name or custom function.

PARAMETER DESCRIPTION
loss

Loss accessor - either attribute name (e.g., "loss") or callable that extracts loss from model output.

TYPE: str | Callable[[ModelOutputT], Array] DEFAULT: 'loss'

optimizer

Optax optimizer or transformation.

TYPE: IOptimizer | MultiSteps | GradientTransformation DEFAULT: adamw(0.001)

Examples:

>>> # Use output.loss attribute
>>> engine = DefaultFlaxTrainingEngine(loss="loss")
>>>
>>> # Use custom loss function
>>> engine = DefaultFlaxTrainingEngine(
...     loss=lambda output: output.loss + 0.01 * output.regularization
... )
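
The loss-accessor resolution can be sketched standalone (assumed names; the actual implementation binds `xgetattr` with `functools.partial`): a string becomes an attribute lookup on the model output, while a callable is used unchanged.

```python
# Standalone sketch of resolving a loss accessor from either a string or a callable.
def resolve_loss(loss):
    if isinstance(loss, str):
        # Attribute-name form: look up that attribute on the output.
        return lambda output: getattr(output, loss)
    # Callable form: use as-is.
    return loss

class Output:
    loss = 0.5

assert resolve_loss("loss")(Output()) == 0.5
assert resolve_loss(lambda o: 2 * o.loss)(Output()) == 1.0
```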
Source code in src/formed/integrations/flax/training/engine.py
def __init__(
    self,
    loss: str | Callable[[ModelOutputT], jax.Array] = "loss",
    optimizer: IOptimizer | optax.MultiSteps | optax.GradientTransformation = optax.adamw(1e-3),
) -> None:
    if not isinstance(optimizer, optax.GradientTransformation):
        optimizer = optax.GradientTransformation(optimizer.init, optimizer.update)  # pyright: ignore[reportArgumentType]

    super().__init__()
    self._loss = partial(xgetattr, name=loss) if isinstance(loss, str) else loss
    self._optimizer = optimizer

create_state

create_state(rngs, trainer, model)
Source code in src/formed/integrations/flax/training/engine.py
def create_state(
    self,
    rngs: nnx.Rngs,
    trainer: "FlaxTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
) -> TrainState:
    graphdef, params, *states = nnx.split(model, nnx.Param, nnx.BatchStat, nnx.RngState)
    return cast(
        TrainState,
        TrainState.create(
            apply_fn=None,
            graphdef=graphdef,
            additional_states=tuple(states),
            params=params,
            tx=self._optimizer,
        ),
    )

train_step

train_step(inputs, state, trainer)
Source code in src/formed/integrations/flax/training/engine.py
def train_step(
    self,
    inputs: ModelInputT,
    state: TrainState,
    trainer: "FlaxTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]",
) -> tuple[TrainState, ModelOutputT]:
    def step(state: TrainState, inputs: ModelInputT) -> tuple[TrainState, ModelOutputT]:
        model = nnx.merge(state.graphdef, state.params, *state.additional_states)
        model.train()

        def loss_fn(
            model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
        ) -> tuple[jax.Array, ModelOutputT]:
            output = model(inputs)
            loss = self._loss(output)
            return loss, output

        (_, output), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)

        graphdef, params, *additional_states = nnx.split(model, nnx.Param, nnx.BatchStat, nnx.RngState)

        grads = trainer.distributor.reduce(grads)
        state = state.replace(
            graphdef=graphdef,
            params=params,
            additional_states=tuple(additional_states),
        )
        state = state.apply_gradients(grads=grads)
        return state, output

    return step(state, inputs)

eval_step

eval_step(inputs, state, trainer)
Source code in src/formed/integrations/flax/training/engine.py
def eval_step(
    self,
    inputs: ModelInputT,
    state: TrainState,
    trainer: "FlaxTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]",
) -> ModelOutputT:
    del trainer

    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT] = nnx.merge(
        state.graphdef,
        state.params,
        *state.additional_states,
    )
    model.eval()
    return model(inputs)

formed.integrations.flax.training.exceptions

StopEarly

Bases: Exception

Raised to stop training early.
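
A minimal sketch of the control flow (an assumed simplification of what the trainer does internally): a callback raises StopEarly, and the training loop catches it so training ends gracefully instead of crashing.

```python
class StopEarly(Exception):
    """Raised to stop training early."""

def train_loop(max_steps, on_step):
    completed = 0
    try:
        for step in range(max_steps):
            on_step(step)  # a callback may raise StopEarly here
            completed += 1
    except StopEarly:
        pass  # end training gracefully instead of propagating
    return completed

def stop_at_three(step):
    if step == 3:
        raise StopEarly()

assert train_loop(10, stop_at_three) == 3
```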

formed.integrations.flax.training.state

Training state management for Flax NNX models.

This module extends Flax's TrainState to support NNX models by storing the graph definition and additional states (BatchStat, RngState) separately from trainable parameters.

Node module-attribute

Node = TypeVar('Node')

TrainState

Bases: TrainState, Generic[Node]

Extended training state for Flax NNX models.

This class extends flax.training.train_state.TrainState to work with Flax NNX models, storing the model's graph definition and additional states (like batch statistics and RNG states) separately from parameters.

ATTRIBUTE DESCRIPTION
graphdef

NNX graph definition describing the model structure.

TYPE: GraphDef[Node]

additional_states

Tuple of additional states (BatchStat, RngState, etc.).

TYPE: tuple[State, ...]

params

Trainable parameters (inherited from TrainState).

TYPE: State

opt_state

Optimizer state (inherited from TrainState).

TYPE: OptState

step

Training step counter (inherited from TrainState).

TYPE: int | Array

tx

Optimizer transformation (inherited from TrainState).

TYPE: GradientTransformation

Examples:

>>> # Create state from model
>>> graphdef, params, *states = nnx.split(model, nnx.Param, nnx.BatchStat)
>>> state = TrainState.create(
...     apply_fn=None,
...     graphdef=graphdef,
...     additional_states=tuple(states),
...     params=params,
...     tx=optimizer
... )
>>>
>>> # Reconstruct model
>>> model = nnx.merge(state.graphdef, state.params, *state.additional_states)

graphdef instance-attribute

graphdef

additional_states class-attribute instance-attribute

additional_states = ()

formed.integrations.flax.training.trainer

High-level trainer for Flax models.

This module provides the FlaxTrainer class, which orchestrates the complete training process for Flax models including data loading, optimization, evaluation, callbacks, and distributed training.

Key Features
  • Flexible training loop with epoch and step-based logging/evaluation
  • Support for callbacks at various training stages
  • Distributed training via data parallelism
  • Rich progress bars with training metrics
  • Early stopping and checkpointing
  • MLflow integration

Examples:

>>> from formed.integrations.flax import (
...     FlaxTrainer,
...     EvaluationCallback,
...     EarlyStoppingCallback
... )
>>> from formed.integrations.ml import DataLoader, BasicBatchSampler
>>> import optax
>>>
>>> # Setup data loaders
>>> train_dataloader = DataLoader(
...     sampler=BasicBatchSampler(batch_size=32, shuffle=True),
...     collator=datamodule.batch
... )
>>>
>>> # Create trainer
>>> trainer = FlaxTrainer(
...     train_dataloader=train_dataloader,
...     val_dataloader=val_dataloader,
...     optimizer=optax.adamw(learning_rate=1e-3),
...     max_epochs=10,
...     callbacks=[
...         EvaluationCallback(my_evaluator),
...         EarlyStoppingCallback(patience=3)
...     ]
... )
>>>
>>> # Train model
>>> rngs = nnx.Rngs(42)
>>> state = trainer.train(rngs, model, train_dataset, val_dataset)

FlaxTrainer

FlaxTrainer(
    *,
    train_dataloader,
    val_dataloader=None,
    engine=None,
    callbacks=(),
    distributor=None,
    max_epochs=10,
    eval_strategy="epoch",
    eval_interval=1,
    logging_strategy="epoch",
    logging_interval=1,
    logging_first_step=True,
    train_prefix="train/",
    val_prefix="val/",
)

Bases: Generic[ItemT, ModelInputT, ModelOutputT, ModelParamsT]

High-level trainer for Flax models.

FlaxTrainer provides a complete training loop with support for distributed training, callbacks, evaluation, and metric logging. It handles the coordination of data loading, model training, evaluation, and callback execution.

CLASS TYPE PARAMETER DESCRIPTION
ItemT

Type of raw dataset items.

ModelInputT

Type of batched model inputs.

ModelOutputT

Type of model outputs.

ModelParamsT

Type of additional model parameters.

PARAMETER DESCRIPTION
train_dataloader

Data loader for training dataset.

TYPE: IDataLoader[ItemT, ModelInputT]

val_dataloader

Optional data loader for validation dataset.

TYPE: IDataLoader[ItemT, ModelInputT] | None DEFAULT: None

engine

Training engine (defaults to DefaultFlaxTrainingEngine).

TYPE: FlaxTrainingEngine[ModelInputT, ModelOutputT, ModelParamsT] | None DEFAULT: None

callbacks

Sequence of training callbacks.

TYPE: Sequence[FlaxTrainingCallback] DEFAULT: ()

distributor

Device distributor (defaults to SingleDeviceDistributor).

TYPE: BaseDistributor | None DEFAULT: None

max_epochs

Maximum number of training epochs.

TYPE: int DEFAULT: 10

eval_strategy

When to evaluate - "epoch" or "step".

TYPE: Literal['epoch', 'step'] DEFAULT: 'epoch'

eval_interval

Evaluation interval (epochs or steps).

TYPE: int DEFAULT: 1

logging_strategy

When to log - "epoch" or "step".

TYPE: Literal['epoch', 'step'] DEFAULT: 'epoch'

logging_interval

Logging interval (epochs or steps).

TYPE: int DEFAULT: 1

logging_first_step

Whether to log after the first training step.

TYPE: bool DEFAULT: True

Examples:

>>> trainer = FlaxTrainer(
...     train_dataloader=train_loader,
...     val_dataloader=val_loader,
...     max_epochs=10,
...     eval_strategy="epoch",
...     logging_strategy="step",
...     logging_interval=100,
...     engine=DefaultFlaxTrainingEngine(
...         optimizer=optax.adamw(1e-3),
...     ),
... )
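
The step-based cadence implied by `logging_strategy="step"` and `logging_interval` can be sketched as a standalone predicate (a simplified mirror of the `is_logging_step` helper in the trainer source; `logging_first_step` additionally fires on step 1):

```python
def is_logging_step(step, interval=100, first_step=True):
    # Fires every `interval` steps, plus once on the very first step.
    return step % interval == 0 or (first_step and step == 1)

assert is_logging_step(1)        # first step is logged
assert is_logging_step(200)      # every 100th step is logged
assert not is_logging_step(150)
```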
Source code in src/formed/integrations/flax/training/trainer.py
def __init__(
    self,
    *,
    train_dataloader: IDataLoader[ItemT, ModelInputT],
    val_dataloader: IDataLoader[ItemT, ModelInputT] | None = None,
    engine: FlaxTrainingEngine[ModelInputT, ModelOutputT, ModelParamsT] | None = None,
    callbacks: Sequence[FlaxTrainingCallback] = (),
    distributor: BaseDistributor | None = None,
    max_epochs: int = 10,
    eval_strategy: Literal["epoch", "step"] = "epoch",
    eval_interval: int = 1,
    logging_strategy: Literal["epoch", "step"] = "epoch",
    logging_interval: int = 1,
    logging_first_step: bool = True,
    train_prefix: str = "train/",
    val_prefix: str = "val/",
) -> None:
    self._train_dataloader = train_dataloader
    self._val_dataloader = val_dataloader
    self._engine = engine or DefaultFlaxTrainingEngine[ModelInputT, ModelOutputT, ModelParamsT]()
    self._distributor = distributor or SingleDeviceDistributor()
    self._max_epochs = max_epochs
    self._eval_strategy = eval_strategy
    self._eval_interval = eval_interval
    self._logging_strategy = logging_strategy
    self._logging_interval = logging_interval
    self._logging_first_step = logging_first_step
    self._callbacks = callbacks
    self._train_prefix = train_prefix
    self._val_prefix = val_prefix

distributor property

distributor

train

train(
    model,
    train_dataset,
    val_dataset=None,
    state=None,
    rngs=None,
)

Train a model on the provided datasets.

PARAMETER DESCRIPTION
model

Model to train.

TYPE: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT]

train_dataset

Sequence of training items.

TYPE: Sequence[ItemT]

val_dataset

Optional sequence of validation items.

TYPE: Sequence[ItemT] | None DEFAULT: None

state

Optional pre-initialized training state (for resuming).

TYPE: TrainState | None DEFAULT: None

rngs

Optional random number generators for initialization.

TYPE: Rngs | None DEFAULT: None

RETURNS DESCRIPTION
TrainState

Final training state with trained parameters.

RAISES DESCRIPTION
ValueError

If val_dataset is provided but val_dataloader is not.

Examples:

>>> state = trainer.train(
...     model, train_items, val_items
... )
>>> # Reconstruct trained model
>>> trained_model = nnx.merge(
...     state.graphdef, state.params, *state.additional_states
... )
Source code in src/formed/integrations/flax/training/trainer.py
def train(
    self,
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    train_dataset: Sequence[ItemT],
    val_dataset: Sequence[ItemT] | None = None,
    state: TrainState | None = None,
    rngs: nnx.Rngs | None = None,
) -> TrainState:
    """Train a model on the provided datasets.

    Args:
        model: Model to train.
        train_dataset: Sequence of training items.
        val_dataset: Optional sequence of validation items.
        state: Optional pre-initialized training state (for resuming).
        rngs: Optional random number generators for initialization.

    Returns:
        Final training state with trained parameters.

    Raises:
        ValueError: If val_dataset is provided but val_dataloader is not.

    Examples:
        >>> state = trainer.train(
        ...     model, train_items, val_items
        ... )
        >>> # Reconstruct trained model
        >>> trained_model = nnx.merge(
        ...     state.graphdef, state.params, *state.additional_states
        ... )

    """
    if val_dataset is not None and self._val_dataloader is None:
        raise ValueError("Validation dataloader is not provided.")

    rngs = rngs or require_rngs()
    logger = use_step_logger(__name__)

    if state is None:
        state = self._engine.create_state(rngs, self, model)

    train_step = self._distributor.map(self._engine.train_step, static_argnums=(2,))
    eval_step = partial(nnx.jit, static_argnames=("trainer"))(self._engine.eval_step)  # pyright: ignore[reportArgumentType]

    for callback in self._callbacks:
        callback.on_training_start(self, model, state)

    def get_total_training_steps() -> int:
        dataloader = self._train_dataloader(train_dataset)
        return len(dataloader) * self._max_epochs

    def get_total_eval_steps() -> int:
        assert val_dataset is not None and self._val_dataloader is not None
        dataloader = self._val_dataloader(val_dataset)
        return len(dataloader)

    def new_epoch(epoch: int) -> None:
        assert state is not None
        logger.info(f"Starting epoch {epoch}/{self._max_epochs}")
        for callback in self._callbacks:
            callback.on_epoch_start(self, model, state, epoch)

    def finalize_epoch(epoch: int) -> None:
        assert state is not None
        for callback in self._callbacks:
            callback.on_epoch_end(self, model, state, epoch)

    def new_batch(epoch: int) -> None:
        assert state is not None
        for callback in self._callbacks:
            callback.on_batch_start(self, model, state, epoch)

    def finalize_batch(epoch: int, output: ModelOutputT) -> None:
        assert state is not None
        for callback in self._callbacks:
            callback.on_batch_end(self, model, state, epoch, output)

    def new_evaluators() -> list[IEvaluator[ModelInputT, ModelOutputT]]:
        assert state is not None
        return [callback.on_eval_start(self, model, state) for callback in self._callbacks]

    def update_metrics(
        evaluators: list[IEvaluator[ModelInputT, ModelOutputT]],
        inputs: ModelInputT,
        output: ModelOutputT,
    ) -> None:
        assert state is not None
        for evaluator in evaluators:
            evaluator.update(inputs, output)

    def compute_metrics(evaluators: list[IEvaluator[ModelInputT, ModelOutputT]]) -> dict[str, float]:
        assert state is not None
        metrics = {}
        for evaluator in evaluators:
            metrics.update(evaluator.compute())
        return metrics

    def finalize_evaluation(metrics: Mapping[str, float], prefix: str) -> None:
        assert state is not None
        for callback in self._callbacks:
            callback.on_eval_end(self, model, state, metrics, prefix)

    def log(metrics: Mapping[str, float], prefix: str) -> None:
        assert state is not None
        if not metrics:
            return
        logger.info("%s", ", ".join(f"{prefix}{k}={v:.4f}" for k, v in metrics.items()))
        for callback in self._callbacks:
            callback.on_log(self, model, state, metrics, prefix=prefix)

    def do_evaluation(progress: Progress) -> None:
        if not val_dataset:
            return

        assert state is not None
        assert self._val_dataloader is not None

        evaluators = new_evaluators()

        task = progress.add_task("Evaluation", total=get_total_eval_steps())
        with closing(self._val_dataloader(val_dataset)) as val_dataloader:
            for batch in val_dataloader:
                output = eval_step(batch, state, self)
                update_metrics(evaluators, batch, output)
                progress.advance(task)
        progress.remove_task(task)

        computed_metrics = compute_metrics(evaluators)
        log(computed_metrics, prefix=self._val_prefix)
        finalize_evaluation(computed_metrics, prefix=self._val_prefix)

    def is_logging_step(step: int) -> bool:
        return (self._logging_strategy == "step" and step % self._logging_interval == 0) or (
            self._logging_first_step and step == 1
        )

    def is_logging_epoch(epoch: int) -> bool:
        return self._logging_strategy == "epoch" and epoch % self._logging_interval == 0

    def is_eval_step(step: int) -> bool:
        return self._eval_strategy == "step" and step % self._eval_interval == 0

    def is_eval_epoch(epoch: int) -> bool:
        return self._eval_strategy == "epoch" and epoch % self._eval_interval == 0

    evaluators = new_evaluators()

    try:
        with Progress(
            SpinnerColumn(),
            TextColumn("{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeRemainingColumn(),
            console=STDERR_CONSOLE,
        ) as progress:
            task = progress.add_task("Training", total=get_total_training_steps())
            for epoch in range(1, self._max_epochs + 1):
                assert state is not None
                new_epoch(epoch)

                with closing(self._train_dataloader(train_dataset)) as train_dataloader:
                    for batch in train_dataloader:
                        new_batch(epoch)

                        sharded_batch = self._distributor.shard(batch)
                        replicated_state = self._distributor.replicate(state)

                        replicated_state, replicated_output = train_step(sharded_batch, replicated_state, self)

                        state = self._distributor.unreplicate(replicated_state)
                        output = self._distributor.unreplicate(replicated_output)
                        assert state is not None

                        update_metrics(evaluators, batch, output)

                        if is_logging_step(int(state.step)):
                            train_metrics = compute_metrics(evaluators)
                            log(train_metrics, prefix=self._train_prefix)
                            finalize_evaluation(train_metrics, prefix=self._train_prefix)
                            evaluators = new_evaluators()

                        finalize_batch(epoch, output)

                        progress.advance(task)

                        if is_eval_step(int(state.step)):
                            do_evaluation(progress)

                if is_logging_epoch(epoch):
                    train_metrics = compute_metrics(evaluators)
                    log(train_metrics, prefix=self._train_prefix)
                    finalize_evaluation(train_metrics, prefix=self._train_prefix)
                    evaluators = new_evaluators()

                if is_eval_epoch(epoch):
                    do_evaluation(progress)

                finalize_epoch(epoch)
    except StopEarly:
        assert state is not None
        logger.info(f"Training stopped early at {state.step} steps.")

    for callback in self._callbacks:
        state = callback.on_training_end(self, model, state)

    return state

formed.integrations.flax.random

use_rngs

use_rngs(default=None)

Context manager to set and restore nnx RNGs.

This context manager allows temporarily setting the nnx RNGs used in Flax/nnx operations. It saves the current RNGs on entry and restores them on exit.

YIELDS DESCRIPTION
Rngs

The current nnx RNGs within the context.

Source code in src/formed/integrations/flax/random.py
@contextmanager
def use_rngs(default: int | None = None) -> Iterator[nnx.Rngs]:
    """Context manager to set and restore nnx RNGs.

    This context manager allows temporarily setting the nnx RNGs
    used in Flax/nnx operations. It saves the current RNGs on entry
    and restores them on exit.

    Yields:
        The current nnx RNGs within the context.

    """
    token = _NNX_RNGS.set(nnx.Rngs(default))
    try:
        yield _NNX_RNGS.get()
    finally:
        _NNX_RNGS.reset(token)

require_rngs

require_rngs()

Get the current nnx RNGs.

RETURNS DESCRIPTION
Rngs

The current nnx RNGs.

Source code in src/formed/integrations/flax/random.py
def require_rngs() -> nnx.Rngs:
    """Get the current nnx RNGs.

    Returns:
        The current nnx RNGs.

    """
    return _NNX_RNGS.get()

formed.integrations.flax.utils

PoolingMethod module-attribute

PoolingMethod = Literal[
    "mean", "max", "min", "sum", "hier", "first", "last"
]

ensure_jax_array

ensure_jax_array(x)
Source code in src/formed/integrations/flax/utils.py
def ensure_jax_array(x: ArrayCompatible) -> jax.Array:
    if isinstance(x, jax.Array):
        return x
    return jax.numpy.asarray(x)

masked_pool

masked_pool(
    embeddings,
    mask=None,
    pooling="mean",
    normalize=False,
    window_size=None,
)

Pool embeddings with a mask.

PARAMETER DESCRIPTION
embeddings

Embeddings to pool of shape (batch_size, sequence_length, embedding_size).

TYPE: Array

mask

Mask of shape (batch_size, sequence_length).

TYPE: Array | None DEFAULT: None

pooling

Pooling method or sequence of methods to apply. A sequence of methods is applied separately and the results are concatenated along the last axis.

TYPE: PoolingMethod | Sequence[PoolingMethod] DEFAULT: 'mean'

normalize

Whether to normalize the embeddings before pooling. Defaults to False.

TYPE: bool DEFAULT: False

window_size

Window size for hierarchical pooling. Defaults to None.

TYPE: int | None DEFAULT: None

Source code in src/formed/integrations/flax/utils.py
def masked_pool(
    embeddings: jax.Array,
    mask: jax.Array | None = None,
    pooling: PoolingMethod | Sequence[PoolingMethod] = "mean",
    normalize: bool = False,
    window_size: int | None = None,
) -> jax.Array:
    """
    Pool embeddings with a mask.

    Args:
        embeddings: Embeddings to pool of shape (batch_size, sequence_length, embedding_size).
        mask: Mask of shape (batch_size, sequence_length).
        pooling: Pooling method or sequence of methods; a sequence is applied separately and concatenated along the last axis.
        normalize: Whether to normalize the embeddings before pooling. Defaults to `False`.
        window_size: Window size for hierarchical pooling. Defaults to `None`.
    """

    if not isinstance(pooling, str):
        return jax.numpy.concatenate(
            [
                masked_pool(
                    embeddings,
                    mask=mask,
                    pooling=method,
                    normalize=normalize,
                    window_size=window_size,
                )
                for method in pooling
            ],
            axis=-1,
        )

    batch_size, sequence_length, embedding_size = embeddings.shape

    if normalize:
        embeddings = embeddings / (jax.numpy.linalg.norm(embeddings, axis=-1, keepdims=True) + 1e-13)

    if mask is None:
        mask = jax.numpy.ones((batch_size, sequence_length), dtype=bool)

    if pooling == "mean":
        embeddings = jax.numpy.where(mask[..., None], embeddings, 0)
        return embeddings.sum(axis=1) / (mask.sum(axis=1, keepdims=True) + 1e-13)

    if pooling == "max":
        embeddings = jax.numpy.where(mask[..., None], embeddings, -jax.numpy.inf)
        return embeddings.max(axis=1)

    if pooling == "min":
        embeddings = jax.numpy.where(mask[..., None], embeddings, jax.numpy.inf)
        return embeddings.min(axis=1)

    if pooling == "sum":
        embeddings = jax.numpy.where(mask[..., None], embeddings, 0)
        return embeddings.sum(axis=1)

    if pooling == "first":
        return embeddings[:, 0, :]

    if pooling == "last":
        batch_indices = jax.numpy.arange(batch_size)
        last_positions = mask.cumsum(axis=1).argmax(axis=1)
        return embeddings[batch_indices, last_positions, :]

    if pooling == "hier":

        def _hierarchical_pooling(vectors: jax.Array, mask: jax.Array) -> jax.Array:
            assert window_size is not None
            vectors = vectors[mask]
            if len(vectors) < window_size:
                return vectors.mean(0)
            output: jax.Array = -jax.numpy.inf * jax.numpy.ones(embedding_size)
            for offset in range(len(vectors) - window_size + 1):
                window = vectors[offset : offset + window_size]
                output = jax.numpy.maximum(output, window.mean(0))
            return output

        output: jax.Array = jax.numpy.array(list(starmap(_hierarchical_pooling, zip(embeddings, mask))))
        return output

    raise ValueError(
        f"pooling must be one of 'mean', 'max', 'min', 'sum', 'hier', 'first', or 'last', but got {pooling}"
    )
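
Masked mean pooling zeroes padded positions and divides by the count of valid ones, so padding tokens cannot skew the average. A NumPy sketch of that arithmetic on a toy batch (not the library function itself):

```python
import numpy as np

# Toy batch: 2 sequences of length 3, embedding size 2.
embeddings = np.array(
    [[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],  # last position is padding
     [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]]
)
mask = np.array([[True, True, False],
                 [True, True, True]])

# Zero out padded positions, then divide by the number of
# valid positions per sequence (epsilon guards empty masks).
masked = np.where(mask[..., None], embeddings, 0.0)
pooled = masked.sum(axis=1) / (mask.sum(axis=1, keepdims=True) + 1e-13)

print(pooled)  # first row averages only the two valid positions: [2., 3.]
```

The large padding values in the first sequence have no effect on its pooled vector, which is the whole point of masking before reduction.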

sequence_distribute

sequence_distribute(
    inputs: Array,
) -> tuple[Array, tuple[int, int]]
sequence_distribute(
    inputs: _MappingT, ignore: Sequence[str] = ...
) -> tuple[_MappingT, tuple[int, int]]
sequence_distribute(inputs, ignore=())
Source code in src/formed/integrations/flax/utils.py
def sequence_distribute(
    inputs: jax.Array | _MappingT,
    ignore: Sequence[str] = (),
) -> tuple[jax.Array | _MappingT, tuple[int, int]]:
    if isinstance(inputs, jax.Array):
        if inputs.ndim < 2:
            return inputs, (-1, -1)
        batch_size, max_length = inputs.shape[:2]
        return inputs.reshape((batch_size * max_length, *inputs.shape[2:])), (batch_size, max_length)
    distributed = [(key, sequence_distribute(value)) for key, value in inputs.items() if key not in ignore]
    arrays = {key: value[0] for key, value in distributed}
    shape = next(s for _, (_, s) in distributed if s != (-1, -1))
    return cast(_MappingT, arrays), shape

sequence_undistribute

sequence_undistribute(
    inputs: Array,
    shape: tuple[int, int],
    ignore: Sequence[str] = ...,
) -> Array
sequence_undistribute(
    inputs: _MappingT,
    shape: tuple[int, int],
    ignore: Sequence[str] = ...,
) -> _MappingT
sequence_undistribute(inputs, shape, ignore=())
Source code in src/formed/integrations/flax/utils.py
def sequence_undistribute(
    inputs: jax.Array | _MappingT,
    shape: tuple[int, int],
    ignore: Sequence[str] = (),
) -> jax.Array | _MappingT:
    if isinstance(inputs, jax.Array):
        return inputs.reshape((shape[0], shape[1], *inputs.shape[1:]))
    return cast(
        _MappingT,
        {key: sequence_undistribute(value, shape) for key, value in inputs.items() if key not in ignore},
    )
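
`sequence_distribute` flattens the leading `(batch, length)` axes into one so a per-token function can be applied, and `sequence_undistribute` restores them using the returned shape. The underlying reshape round-trip, sketched with NumPy:

```python
import numpy as np

batch_size, max_length, hidden = 2, 3, 4
tokens = np.arange(batch_size * max_length * hidden, dtype=np.float32).reshape(
    batch_size, max_length, hidden
)

# Distribute: merge (batch, length) into a single leading axis,
# remembering the original shape for the inverse.
flat = tokens.reshape(batch_size * max_length, *tokens.shape[2:])
shape = (batch_size, max_length)
assert flat.shape == (6, 4)

# ... apply some per-token computation to `flat` here ...

# Undistribute: split the leading axis back out.
restored = flat.reshape(shape[0], shape[1], *flat.shape[1:])
assert np.array_equal(restored, tokens)
```

Arrays with fewer than two dimensions are passed through unchanged by `sequence_distribute`, reporting the sentinel shape `(-1, -1)`.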

determine_ndim

determine_ndim(first, *args)
Source code in src/formed/integrations/flax/utils.py
def determine_ndim(
    first: int,
    *args: int | Callable[[int], int] | None,
) -> int:
    output_dim = first
    for arg in args:
        if arg is None:
            continue
        if callable(arg):
            output_dim = arg(output_dim)
        else:
            output_dim = arg
    return output_dim

formed.integrations.flax.workflow

Workflow integration for Flax model training.

This module provides workflow steps for training Flax models, allowing them to be integrated into the formed workflow system with automatic caching and dependency tracking.

Available Steps
  • flax::train: Train a Flax model using the provided trainer.
  • flax::evaluate: Evaluate a Flax model on a dataset.

Examples:

>>> from formed.integrations.flax import train_flax_model
>>>
>>> # In workflow configuration (jsonnet):
>>> # {
>>> #   steps: {
>>> #     train: {
>>> #       type: "flax::train",
>>> #       model: { type: "my_model", ... },
>>> #       trainer: { type: "flax_trainer", ... },
>>> #       train_dataset: { type: "ref", ref: "preprocess" },
>>> #       random_seed: 42
>>> #     }
>>> #   }
>>> # }

FlaxModelFormat

Bases: Format[BaseFlaxModel]

identifier property

identifier

Get the unique identifier for this format.

RETURNS DESCRIPTION
str

Format identifier string.

write

write(artifact, directory)
Source code in src/formed/integrations/flax/workflow.py
def write(self, artifact: BaseFlaxModel, directory: Path) -> None:
    if (config := getattr(artifact, "__model_config__", None)) is not None:
        config = dict(artifact.__model_config__)
        config[COLT_TYPEKEY] = f"{artifact.__class__.__module__}:{artifact.__class__.__name__}"
        del artifact.__model_config__
        self._get_config_path(directory).write_text(json.dumps(config, cls=WorkflowJSONEncoder))
        self._get_checkpointer(directory).save(0, artifact)
    else:
        with self._get_pickle_path(directory).open("wb") as f:
            cloudpickle.dump(artifact, f)

read

read(directory)
Source code in src/formed/integrations/flax/workflow.py
def read(self, directory: Path) -> BaseFlaxModel:
    if (pickle_path := self._get_pickle_path(directory)).exists():
        with pickle_path.open("rb") as f:
            return cloudpickle.load(f)

    with use_rngs(0):
        model = COLT_BUILDER(
            json.loads(
                self._get_config_path(directory).read_text(),
                cls=WorkflowJSONDecoder,
            )
        )

    return cast(BaseFlaxModel, self._get_checkpointer(directory).restore(0, items=model))

is_default_of classmethod

is_default_of(obj)

Check if this format is the default for the given object type.

PARAMETER DESCRIPTION
obj

Object to check.

TYPE: Any

RETURNS DESCRIPTION
bool

True if this format should be used by default for this type.

Source code in src/formed/workflow/format.py
@classmethod
def is_default_of(cls, obj: Any) -> bool:
    """Check if this format is the default for the given object type.

    Args:
        obj: Object to check.

    Returns:
        True if this format should be used by default for this type.

    """
    return False

train_flax_model

train_flax_model(
    model,
    trainer,
    train_dataset,
    val_dataset=None,
    random_seed=0,
)

Train a Flax model using the provided trainer.

This workflow step trains a Flax NNX model on the provided datasets, returning the trained model. The training process is cached based on the model architecture, trainer configuration, and dataset fingerprints.

PARAMETER DESCRIPTION
model

Flax model to train.

TYPE: Lazy[BaseFlaxModel]

trainer

Trainer configuration with dataloaders and callbacks.

TYPE: FlaxTrainer

train_dataset

Training dataset items.

TYPE: Sequence[ItemT]

val_dataset

Optional validation dataset items.

TYPE: Sequence[ItemT] | None DEFAULT: None

random_seed

Random seed for reproducibility.

TYPE: int DEFAULT: 0

RETURNS DESCRIPTION
BaseFlaxModel

Trained Flax model with updated parameters.

Examples:

>>> # Use in Python code
>>> trained_model = train_flax_model(
...     model=my_model,
...     trainer=trainer,
...     train_dataset=train_data,
...     val_dataset=val_data,
...     random_seed=42
... )
Source code in src/formed/integrations/flax/workflow.py
@step("flax::train", format=FlaxModelFormat())
def train_flax_model(
    model: Lazy[BaseFlaxModel],
    trainer: FlaxTrainer,
    train_dataset: Sequence[ItemT],
    val_dataset: Sequence[ItemT] | None = None,
    random_seed: int = 0,
) -> BaseFlaxModel:
    """Train a Flax model using the provided trainer.

    This workflow step trains a Flax NNX model on the provided datasets,
    returning the trained model. The training process is cached based on
    the model architecture, trainer configuration, and dataset fingerprints.

    Args:
        model: Flax model to train.
        trainer: Trainer configuration with dataloaders and callbacks.
        train_dataset: Training dataset items.
        val_dataset: Optional validation dataset items.
        random_seed: Random seed for reproducibility.

    Returns:
        Trained Flax model with updated parameters.

    Examples:
        >>> # Use in Python code
        >>> trained_model = train_flax_model(
        ...     model=my_model,
        ...     trainer=trainer,
        ...     train_dataset=train_data,
        ...     val_dataset=val_data,
        ...     random_seed=42
        ... )

    """

    with use_rngs(random_seed):
        model_instance = model.construct()
        state = trainer.train(model_instance, train_dataset, val_dataset)

    model_instance = nnx.merge(state.graphdef, state.params, *state.additional_states)
    model_instance.__model_config__ = model.config
    return model_instance

evaluate_flax_model

evaluate_flax_model(
    model,
    evaluator,
    dataset,
    dataloader,
    params=None,
    random_seed=None,
)

Evaluate a Flax model on a dataset using the provided evaluator.

PARAMETER DESCRIPTION
model

Flax model to evaluate.

TYPE: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT]

evaluator

Evaluator to compute metrics.

TYPE: IEvaluator[ModelInputT, ModelOutputT]

dataset

Dataset items for evaluation.

TYPE: list[ItemT]

dataloader

DataLoader to convert items to model inputs.

TYPE: IDataLoader[ItemT, ModelInputT]

params

Optional model parameters to use for evaluation.

TYPE: ModelParamsT | None DEFAULT: None

random_seed

Optional random seed for reproducibility.

TYPE: int | None DEFAULT: None

RETURNS DESCRIPTION
dict[str, float]

Dictionary of computed evaluation metrics.

Source code in src/formed/integrations/flax/workflow.py
@step("flax::evaluate", format="json")
def evaluate_flax_model(
    model: BaseFlaxModel[ModelInputT, ModelOutputT, ModelParamsT],
    evaluator: IEvaluator[ModelInputT, ModelOutputT],
    dataset: list[ItemT],
    dataloader: IDataLoader[ItemT, ModelInputT],
    params: ModelParamsT | None = None,
    random_seed: int | None = None,
) -> Annotated[dict[str, float], WorkflowStepResultFlag.METRICS]:
    """Evaluate a Flax model on a dataset using the provided evaluator.

    Args:
        model: Flax model to evaluate.
        evaluator: Evaluator to compute metrics.
        dataset: Dataset items for evaluation.
        dataloader: DataLoader to convert items to model inputs.
        params: Optional model parameters to use for evaluation.
        random_seed: Optional random seed for reproducibility.

    Returns:
        Dictionary of computed evaluation metrics.
    """

    logger = use_step_logger(__name__)

    with use_rngs(random_seed):
        model.eval()
        evaluator.reset()

        with (
            closing(dataloader(dataset)) as loader,
            progress(loader, desc="Evaluating model") as iterator,
        ):
            for inputs in iterator:
                output = model(inputs, params)
                evaluator.update(inputs, output)

        metrics = evaluator.compute()
        logger.info("Evaluation metrics: %s", ", ".join(f"{k}={v:.4f}" for k, v in metrics.items()))

    return metrics