Flax¶
- Distributors
- Model
- Modules
- Training
- Random
- Utils
- Workflow
formed.integrations.flax.distributors
¶
Distributed computing abstractions for Flax models.
This module provides abstractions for distributed training across multiple devices, supporting both single-device and data-parallel training strategies.
Key Components
- BaseDistributor: Abstract interface for device distribution strategies
- SingleDeviceDistributor: No-op distributor for single-device training
- DataParallelDistributor: Data-parallel training using JAX pmap
Features
- Transparent device sharding and replication
- JIT compilation with device-specific optimizations
- Reduction operations (mean, sum) across devices
- Compatible with FlaxTrainer
Examples:
>>> from formed.integrations.flax import DataParallelDistributor
>>> import jax
>>>
>>> # Create data-parallel distributor for all available devices
>>> distributor = DataParallelDistributor(axis_name="batch")
>>>
>>> # Shard batch across devices
>>> sharded_batch = distributor.shard(batch)
>>>
>>> # Map function across devices
>>> train_step = distributor.map(training_function)
>>> outputs = train_step(sharded_batch, state)
BaseDistributor
¶
Bases: Registrable, ABC, Generic[ModelInputT]
Abstract base class for device distribution strategies.
BaseDistributor defines the interface for distributing computations across devices in a JAX/Flax training pipeline. It handles data sharding, replication, and reduction operations.
| CLASS TYPE PARAMETER | DESCRIPTION |
|---|---|
| `ModelInputT` | Type of model input data. |
shard
¶
shard(inputs)
Shard inputs across devices.
| PARAMETER | DESCRIPTION |
|---|---|
| `inputs` | Input data to shard. |

| RETURNS | DESCRIPTION |
|---|---|
| `ModelInputT` | Sharded input data with an additional device dimension. |

Source code in src/formed/integrations/flax/distributors.py
replicate
¶
replicate(inputs)
Replicate data across all devices.
| PARAMETER | DESCRIPTION |
|---|---|
| `inputs` | Data to replicate. |

| RETURNS | DESCRIPTION |
|---|---|
| `_T` | Replicated data with device dimension. |

Source code in src/formed/integrations/flax/distributors.py
unreplicate
¶
unreplicate(inputs)
Extract data from the first device, removing device dimension.
| PARAMETER | DESCRIPTION |
|---|---|
| `inputs` | Replicated data with device dimension. |

| RETURNS | DESCRIPTION |
|---|---|
| `_T` | Data from first device without device dimension. |

Source code in src/formed/integrations/flax/distributors.py
map
abstractmethod
¶
map(fn, static_argnums=())
Map a function across devices with JIT compilation.
| PARAMETER | DESCRIPTION |
|---|---|
| `fn` | Function to map across devices. |
| `static_argnums` | Indices of static arguments. |

| RETURNS | DESCRIPTION |
|---|---|
| `_CallableT` | Mapped and compiled function. |

Source code in src/formed/integrations/flax/distributors.py
reduce
abstractmethod
¶
reduce(array, op='mean')
Reduce an array across devices.
| PARAMETER | DESCRIPTION |
|---|---|
| `array` | Array to reduce. |
| `op` | Reduction operation ("mean" or "sum"). |

| RETURNS | DESCRIPTION |
|---|---|
| `_ArrayT` | Reduced array. |

Source code in src/formed/integrations/flax/distributors.py
SingleDeviceDistributor
¶
Bases: BaseDistributor[ModelInputT]
Distributor for single-device training.
This distributor applies JIT compilation without any device distribution. All shard, replicate, and unreplicate operations are no-ops.
Examples:
>>> distributor = SingleDeviceDistributor()
>>> train_step = distributor.map(my_train_function)
>>> output = train_step(batch, state, trainer)
map
¶
map(fn, static_argnums=())
Apply JIT compilation to a function.
| PARAMETER | DESCRIPTION |
|---|---|
| `fn` | Function to compile. |
| `static_argnums` | Indices of static arguments. |

| RETURNS | DESCRIPTION |
|---|---|
| `_CallableT` | JIT-compiled function. |

Source code in src/formed/integrations/flax/distributors.py
reduce
¶
reduce(array, op='mean')
Return array unchanged (no reduction needed for single device).
| PARAMETER | DESCRIPTION |
|---|---|
| `array` | Input array. |
| `op` | Reduction operation (ignored). |

| RETURNS | DESCRIPTION |
|---|---|
| `_ArrayT` | Input array unchanged. |

Source code in src/formed/integrations/flax/distributors.py
shard
¶
shard(inputs)
Shard inputs across devices.
| PARAMETER | DESCRIPTION |
|---|---|
| `inputs` | Input data to shard. |

| RETURNS | DESCRIPTION |
|---|---|
| `ModelInputT` | Sharded input data with an additional device dimension. |

Source code in src/formed/integrations/flax/distributors.py
replicate
¶
replicate(inputs)
Replicate data across all devices.
| PARAMETER | DESCRIPTION |
|---|---|
| `inputs` | Data to replicate. |

| RETURNS | DESCRIPTION |
|---|---|
| `_T` | Replicated data with device dimension. |

Source code in src/formed/integrations/flax/distributors.py
unreplicate
¶
unreplicate(inputs)
Extract data from the first device, removing device dimension.
| PARAMETER | DESCRIPTION |
|---|---|
| `inputs` | Replicated data with device dimension. |

| RETURNS | DESCRIPTION |
|---|---|
| `_T` | Data from first device without device dimension. |

Source code in src/formed/integrations/flax/distributors.py
DataParallelDistributor
¶
DataParallelDistributor(
axis_name="batch", num_devices=None
)
Bases: BaseDistributor[ModelInputT]
Distributor for data-parallel training across multiple devices.
This distributor uses JAX's pmap to execute the same computation on different data shards across multiple devices. Data is automatically sharded along the batch dimension.
| PARAMETER | DESCRIPTION |
|---|---|
| `axis_name` | Name for the device axis (used in reduction operations). |
| `num_devices` | Number of devices to use. Defaults to all local devices. |
Examples:
>>> # Train on 4 GPUs with data parallelism
>>> distributor = DataParallelDistributor(
... axis_name="batch",
... num_devices=4
... )
>>>
>>> # Shard batch of size 32 into 4 shards of size 8
>>> sharded = distributor.shard(batch)
>>> assert sharded.shape == (4, 8, ...)
>>>
>>> # Map training step across devices
>>> train_step = distributor.map(my_train_step)
Note
Batch size must be divisible by num_devices for proper sharding.
Source code in src/formed/integrations/flax/distributors.py
shard
¶
shard(inputs)
Shard inputs along the batch dimension across devices.
| PARAMETER | DESCRIPTION |
|---|---|
| `inputs` | Input data with batch dimension. |

| RETURNS | DESCRIPTION |
|---|---|
| `ModelInputT` | Sharded inputs with shape (num_devices, batch_per_device, ...). |

Source code in src/formed/integrations/flax/distributors.py
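The reshaping that sharding performs can be illustrated with a minimal NumPy sketch (not the library's implementation, which places shards on real devices): a batch is split along its leading dimension into one slice per device.

```python
import numpy as np

# Hypothetical sketch of data-parallel sharding: a batch of 32 examples is
# reshaped to add a leading device dimension, assuming 4 devices.
num_devices = 4
batch = np.arange(32 * 16, dtype=np.float32).reshape(32, 16)  # (batch, features)

per_device = batch.shape[0] // num_devices  # batch size must divide evenly
sharded = batch.reshape(num_devices, per_device, *batch.shape[1:])

print(sharded.shape)  # (4, 8, 16)
```

Unreplicating or unsharding is the inverse reshape, which is why the batch size must be divisible by the device count.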
replicate
¶
replicate(inputs)
Replicate data across all devices.
| PARAMETER | DESCRIPTION |
|---|---|
| `inputs` | Data to replicate. |

| RETURNS | DESCRIPTION |
|---|---|
| `_T` | Replicated data with device dimension. |

Source code in src/formed/integrations/flax/distributors.py
unreplicate
¶
unreplicate(inputs)
Extract data from the first device.
| PARAMETER | DESCRIPTION |
|---|---|
| `inputs` | Replicated data. |

| RETURNS | DESCRIPTION |
|---|---|
| `_T` | Data from first device. |

Source code in src/formed/integrations/flax/distributors.py
map
¶
map(fn, static_argnums=())
Map function across devices using pmap.
| PARAMETER | DESCRIPTION |
|---|---|
| `fn` | Function to parallelize. |
| `static_argnums` | Indices of static arguments to broadcast. |

| RETURNS | DESCRIPTION |
|---|---|
| `_CallableT` | Parallelized function using pmap. |

Source code in src/formed/integrations/flax/distributors.py
reduce
¶
reduce(array, op='mean')
Reduce array across devices.
| PARAMETER | DESCRIPTION |
|---|---|
| `array` | Array to reduce across device dimension. |
| `op` | Reduction operation ("mean" or "sum"). |

| RETURNS | DESCRIPTION |
|---|---|
| `_ArrayT` | Reduced array. |

| RAISES | DESCRIPTION |
|---|---|
| `ValueError` | If unsupported reduction operation is specified. |

Source code in src/formed/integrations/flax/distributors.py
formed.integrations.flax.model
¶
Base model abstraction for Flax NNX models.
This module provides the base class for all Flax models in the framework, integrating Flax NNX with the registrable pattern for configuration-based model instantiation.
Key Features
- Integration with Flax NNX Module system
- Registrable pattern for configuration-based instantiation
- Generic type support for inputs, outputs, and parameters
- Compatible with FlaxTrainer for end-to-end training
Examples:
>>> from formed.integrations.flax import BaseFlaxModel
>>> from flax import nnx
>>> import jax
>>>
>>> @BaseFlaxModel.register("my_model")
... class MyModel(BaseFlaxModel[dict, jax.Array, None]):
... def __init__(self, rngs: nnx.Rngs, hidden_dim: int):
... self.linear = nnx.Linear(10, hidden_dim, rngs=rngs)
...
... def __call__(self, inputs: dict, params: None = None) -> jax.Array:
... return self.linear(inputs["features"])
BaseFlaxModel
¶
Bases: Module, Registrable, Generic[ModelInputT, ModelOutputT, ModelParamsT]
Base class for all Flax NNX models in the framework.
This class combines Flax's NNX Module with the registrable pattern, allowing models to be instantiated from configuration files and seamlessly integrated with the training infrastructure.
| CLASS TYPE PARAMETER | DESCRIPTION |
|---|---|
| `ModelInputT` | Type of input data to the model. |
| `ModelOutputT` | Type of model output. |
| `ModelParamsT` | Type of additional parameters (typically `None`). |
Note
Subclasses should implement __call__ to define the forward pass.
Models are automatically compatible with FlaxTrainer when registered.
formed.integrations.flax.modules.embedders
¶
Text embedding modules for Flax models.
This module provides embedders that convert tokenized text into dense vector representations. Embedders handle various text representations including surface forms, part-of-speech tags, and character sequences.
Key Components
- BaseEmbedder: Abstract base class for all embedders
- TokenEmbedder: Embeds token ID sequences into dense vectors
- AnalyzedTextEmbedder: Combines multiple embedding types (surface, POS, chars)
Features
- Support for nested token sequences (e.g., word -> character)
- Automatic masking and padding handling
- Configurable vectorization for character-level embeddings
- Concatenation of multiple embedding types
Examples:
>>> from formed.integrations.flax.modules import TokenEmbedder, AnalyzedTextEmbedder
>>> from flax import nnx
>>>
>>> # Simple token embedder
>>> embedder = TokenEmbedder(
... vocab_size=10000,
... embedding_dim=128,
... rngs=nnx.Rngs(0)
... )
>>>
>>> # Multi-feature embedder
>>> embedder = AnalyzedTextEmbedder(
... surface=TokenEmbedder(vocab_size=10000, embedding_dim=128, rngs=rngs),
... postag=TokenEmbedder(vocab_size=50, embedding_dim=32, rngs=rngs)
... )
EmbedderOutput
¶
Bases: NamedTuple
Output from an embedder.
| ATTRIBUTE | DESCRIPTION |
|---|---|
| `embeddings` | Dense embeddings of shape (batch_size, seq_len, embedding_dim). |
| `mask` | Attention mask of shape (batch_size, seq_len). |
BaseEmbedder
¶
Bases: Module, Registrable, Generic[_TextBatchT], ABC
Abstract base class for text embedders.
Embedders convert tokenized text into dense vector representations. They output both embeddings and attention masks.
| CLASS TYPE PARAMETER | DESCRIPTION |
|---|---|
| `_TextBatchT` | Type of input batch (e.g., IIDSequenceBatch, IAnalyzedTextBatch). |
get_output_dim
abstractmethod
¶
get_output_dim()
Get the output embedding dimension.
| RETURNS | DESCRIPTION |
|---|---|
| `int` | Embedding dimension. |

Source code in src/formed/integrations/flax/modules/embedders.py
TokenEmbedder
¶
TokenEmbedder(
vocab_size, embedding_dim, *, vectorizer=None, rngs=None
)
Bases: BaseEmbedder[IIDSequenceBatch]
Embedder for token ID sequences.
This embedder converts token IDs into dense embeddings using a learned embedding matrix. It supports both 2D (batch_size, seq_len) and 3D (batch_size, seq_len, char_len) token ID tensors.
For 3D inputs (e.g., character-level tokens within words), the embedder can either average the embeddings or apply a custom vectorizer.
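The default averaging path for 3D inputs can be sketched in NumPy (a hypothetical illustration with a random embedding table, not the module's actual code): character embeddings are looked up per word, then mean-pooled over the character axis.

```python
import numpy as np

# Sketch of the 3D case: character IDs of shape (batch, words, chars) are
# embedded via table lookup and averaged over the character axis so that
# each word gets a single vector.
rng = np.random.default_rng(0)
vocab_size, embedding_dim = 256, 32
table = rng.normal(size=(vocab_size, embedding_dim))  # illustration only

char_ids = rng.integers(0, vocab_size, size=(2, 5, 7))  # (B, W, C)
char_embeddings = table[char_ids]                       # (B, W, C, D)
word_vectors = char_embeddings.mean(axis=2)             # (B, W, D)
print(word_vectors.shape)  # (2, 5, 32)
```

A custom vectorizer replaces the `mean` step with another pooling strategy.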
| PARAMETER | DESCRIPTION |
|---|---|
| `vocab_size` | Size of the vocabulary. |
| `embedding_dim` | Dimension of the embedding vectors. |
| `vectorizer` | Optional vectorizer for 3D inputs (character sequences). |
| `rngs` | Random number generators. |
Examples:
>>> # Simple word embeddings
>>> embedder = TokenEmbedder(vocab_size=10000, embedding_dim=128, rngs=rngs)
>>> output = embedder(word_ids_batch)
>>>
>>> # Character-level embeddings with pooling
>>> from formed.integrations.flax.modules import BagOfEmbeddingsSequenceVectorizer
>>> embedder = TokenEmbedder(
... vocab_size=256,
... embedding_dim=32,
... vectorizer=BagOfEmbeddingsSequenceVectorizer(pooling="max"),
... rngs=rngs
... )
Source code in src/formed/integrations/flax/modules/embedders.py
get_output_dim
¶
get_output_dim()
Source code in src/formed/integrations/flax/modules/embedders.py
AnalyzedTextEmbedder
¶
AnalyzedTextEmbedder(
surface=None, postag=None, character=None
)
Bases: BaseEmbedder[IAnalyzedTextBatch]
Embedder for analyzed text with multiple linguistic features.
This embedder combines embeddings from multiple linguistic representations (surface forms, part-of-speech tags, character sequences) by concatenating them along the feature dimension.
| PARAMETER | DESCRIPTION |
|---|---|
| `surface` | Optional embedder for surface form tokens. |
| `postag` | Optional embedder for part-of-speech tags. |
| `character` | Optional embedder for character sequences. |

| RAISES | DESCRIPTION |
|---|---|
| `ValueError` | If all embedders are None (at least one is required). |
Examples:
>>> from formed.integrations.flax.modules import (
... AnalyzedTextEmbedder,
... TokenEmbedder
... )
>>>
>>> embedder = AnalyzedTextEmbedder(
... surface=TokenEmbedder(vocab_size=10000, embedding_dim=128, rngs=rngs),
... postag=TokenEmbedder(vocab_size=50, embedding_dim=32, rngs=rngs),
... character=TokenEmbedder(vocab_size=256, embedding_dim=32, rngs=rngs)
... )
>>>
>>> # Output dimension is sum of all embedding dimensions (128 + 32 + 32 = 192)
>>> assert embedder.get_output_dim() == 192
Note
All provided embedders share the same mask, which is taken from the last non-None embedder processed.
Source code in src/formed/integrations/flax/modules/embedders.py
get_output_dim
¶
get_output_dim()
Source code in src/formed/integrations/flax/modules/embedders.py
formed.integrations.flax.modules.encoders
¶
Sequence encoding modules for Flax models.
This module provides encoders that process sequential data, including position encoders, RNN-based encoders, and transformer encoders.
Key Components
Position Encoders:
- BasePositionEncoder: Abstract base for position encoding
- SinusoidalPositionEncoder: Sinusoidal position embeddings
- LearnablePositionEncoder: Learned position embeddings
Sequence Encoders:
- BaseSequenceEncoder: Abstract base for sequence encoders
- RNNSequenceEncoder: Generic RNN encoder (LSTM, GRU, vanilla RNN)
- LSTMSequenceEncoder: LSTM-specific encoder
- OptimizedLSTMSequenceEncoder: Optimized LSTM encoder
- GRUSequenceEncoder: GRU-specific encoder
- TransformerSequenceEncoder: Transformer encoder with multi-head attention
Features
- Bidirectional RNN support
- Stacked layers with dropout
- Position encoding for transformers
- Efficient implementation using scan and vmap
- Masked sequence processing
Examples:
>>> from formed.integrations.flax.modules import (
... LSTMSequenceEncoder,
... TransformerSequenceEncoder,
... SinusoidalPositionEncoder
... )
>>> from flax import nnx
>>>
>>> # Bidirectional LSTM encoder
>>> encoder = LSTMSequenceEncoder(
... features=128,
... num_layers=2,
... bidirectional=True,
... dropout=0.1,
... rngs=nnx.Rngs(0)
... )
>>>
>>> # Transformer encoder with sinusoidal positions
>>> encoder = TransformerSequenceEncoder(
... features=128,
... num_heads=8,
... num_layers=6,
... position_encoder=SinusoidalPositionEncoder(),
... rngs=rngs
... )
BasePositionEncoder
¶
Bases: Registrable, ABC
Abstract base class for position encoders.
Position encoders add positional information to input embeddings, allowing models to understand token positions in sequences.
SinusoidalPositionEncoder
¶
SinusoidalPositionEncoder(max_length=512)
Bases: BasePositionEncoder
Sinusoidal position encoding from "Attention Is All You Need".
This encoder uses sine and cosine functions of different frequencies to generate position embeddings without learnable parameters.
| PARAMETER | DESCRIPTION |
|---|---|
| `max_length` | Maximum sequence length to support. |
Examples:
>>> encoder = SinusoidalPositionEncoder(max_length=512)
>>> encoded = encoder(embeddings)
Note
Encodings are cached for efficiency.
Source code in src/formed/integrations/flax/modules/encoders.py
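The sinusoidal scheme from "Attention Is All You Need" can be written in a few lines of NumPy; this is a standard reference sketch, not the encoder's exact implementation.

```python
import numpy as np

def sinusoidal_positions(max_length: int, features: int) -> np.ndarray:
    # Even dimensions carry sine, odd dimensions carry cosine, with
    # wavelengths forming a geometric progression up to 10000 * 2*pi.
    positions = np.arange(max_length)[:, None]               # (L, 1)
    dims = np.arange(0, features, 2)[None, :]                # (1, D/2)
    angles = positions / np.power(10000.0, dims / features)  # (L, D/2)
    encodings = np.zeros((max_length, features))
    encodings[:, 0::2] = np.sin(angles)
    encodings[:, 1::2] = np.cos(angles)
    return encodings

enc = sinusoidal_positions(512, 128)
print(enc.shape)  # (512, 128)
```

Because the table is a pure function of position, it can be computed once and cached, which matches the note above.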
LearnablePositionEncoder
¶
LearnablePositionEncoder(
features, *, max_length=512, rngs=None
)
Bases: BasePositionEncoder
Learnable position embeddings.
This encoder uses a learned embedding matrix for position encoding, allowing the model to learn task-specific positional patterns.
| PARAMETER | DESCRIPTION |
|---|---|
| `features` | Embedding dimension. |
| `max_length` | Maximum sequence length to support. |
| `rngs` | Random number generators. |
Examples:
>>> encoder = LearnablePositionEncoder(
... features=128,
... max_length=512,
... rngs=rngs
... )
Source code in src/formed/integrations/flax/modules/encoders.py
BaseSequenceEncoder
¶
Bases: Module, Registrable, ABC
Abstract base class for sequence encoders.
Sequence encoders process sequential data and output contextual representations for each position in the sequence.
get_input_dim
abstractmethod
¶
get_input_dim()
Get the expected input dimension.
| RETURNS | DESCRIPTION |
|---|---|
| `int | None` | Input feature dimension or None if dimension-agnostic. |

Source code in src/formed/integrations/flax/modules/encoders.py
get_output_dim
abstractmethod
¶
get_output_dim()
Get the output dimension.
| RETURNS | DESCRIPTION |
|---|---|
| `int | Callable[[int], int]` | Output feature dimension or a function mapping input dim to output dim. |

Source code in src/formed/integrations/flax/modules/encoders.py
RNNSequenceEncoder
¶
RNNSequenceEncoder(
cell_factory,
features,
num_layers=1,
bidirectional=False,
feedforward_layers=None,
dropout=0.0,
rngs=None,
)
Bases: BaseSequenceEncoder
Generic RNN-based sequence encoder.
This encoder supports various RNN cell types (LSTM, GRU, vanilla RNN) with optional bidirectionality, multiple layers, and dropout.
| PARAMETER | DESCRIPTION |
|---|---|
| `cell_factory` | Function that creates an RNN cell given rngs. |
| `num_layers` | Number of stacked RNN layers. |
| `bidirectional` | Whether to use bidirectional RNN. |
| `dropout` | Dropout rate between layers. |
| `rngs` | Random number generators (int or nnx.Rngs). |
Note
This is a base class. Use LSTMSequenceEncoder, GRUSequenceEncoder, or OptimizedLSTMSequenceEncoder for specific cell types.
Source code in src/formed/integrations/flax/modules/encoders.py
get_input_dim
¶
get_input_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
get_output_dim
¶
get_output_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
LSTMSequenceEncoder
¶
LSTMSequenceEncoder(
features,
num_layers=1,
bidirectional=False,
feedforward_layers=None,
dropout=0.0,
rngs=None,
)
Bases: RNNSequenceEncoder
LSTM-based sequence encoder.
| PARAMETER | DESCRIPTION |
|---|---|
| `features` | Hidden dimension. |
| `num_layers` | Number of LSTM layers. |
| `bidirectional` | Whether to use bidirectional LSTM. |
| `dropout` | Dropout rate between layers. |
| `rngs` | Random number generators. |
Examples:
>>> # Bidirectional 2-layer LSTM
>>> encoder = LSTMSequenceEncoder(
... features=128,
... num_layers=2,
... bidirectional=True,
... dropout=0.1,
... rngs=0
... )
Source code in src/formed/integrations/flax/modules/encoders.py
get_input_dim
¶
get_input_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
get_output_dim
¶
get_output_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
OptimizedLSTMSequenceEncoder
¶
OptimizedLSTMSequenceEncoder(
features,
num_layers=1,
bidirectional=False,
feedforward_layers=None,
dropout=0.0,
rngs=None,
)
Bases: RNNSequenceEncoder
Optimized LSTM sequence encoder.
Uses Flax's optimized LSTM implementation for better performance.
| PARAMETER | DESCRIPTION |
|---|---|
| `features` | Hidden dimension. |
| `num_layers` | Number of LSTM layers. |
| `bidirectional` | Whether to use bidirectional LSTM. |
| `dropout` | Dropout rate between layers. |
| `rngs` | Random number generators. |

Source code in src/formed/integrations/flax/modules/encoders.py
get_input_dim
¶
get_input_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
get_output_dim
¶
get_output_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
GRUSequenceEncoder
¶
GRUSequenceEncoder(
features,
num_layers=1,
bidirectional=False,
feedforward_layers=None,
dropout=0.0,
rngs=None,
)
Bases: RNNSequenceEncoder
GRU-based sequence encoder.
| PARAMETER | DESCRIPTION |
|---|---|
| `features` | Hidden dimension. |
| `num_layers` | Number of GRU layers. |
| `bidirectional` | Whether to use bidirectional GRU. |
| `dropout` | Dropout rate between layers. |
| `rngs` | Random number generators. |
Examples:
>>> encoder = GRUSequenceEncoder(
... features=256,
... num_layers=3,
... bidirectional=True,
... rngs=0
... )
Source code in src/formed/integrations/flax/modules/encoders.py
get_input_dim
¶
get_input_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
get_output_dim
¶
get_output_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
TransformerSequenceEncoder
¶
TransformerSequenceEncoder(
features,
num_heads,
*,
num_layers=1,
dropout=0.0,
epsilon=1e-06,
feedworward_features=None,
activation=gelu,
position_encoder=None,
rngs=None,
)
Bases: BaseSequenceEncoder
Transformer-based sequence encoder.
This encoder uses multi-head self-attention and feed-forward layers to process sequences, following the Transformer architecture.
| PARAMETER | DESCRIPTION |
|---|---|
| `features` | Model dimension. |
| `num_heads` | Number of attention heads. |
| `num_layers` | Number of transformer layers. |
| `dropout` | Dropout rate. |
| `epsilon` | Layer normalization epsilon. |
| `feedworward_features` | Feed-forward hidden dimension (defaults to 4*features). |
| `activation` | Activation function for feed-forward layers. |
| `position_encoder` | Optional position encoder. |
| `rngs` | Random number generators. |
Examples:
>>> # Transformer with sinusoidal positions
>>> encoder = TransformerSequenceEncoder(
... features=512,
... num_heads=8,
... num_layers=6,
... dropout=0.1,
... position_encoder=SinusoidalPositionEncoder(),
... rngs=rngs
... )
>>>
>>> # Transformer with learnable positions
>>> encoder = TransformerSequenceEncoder(
... features=512,
... num_heads=8,
... num_layers=6,
... position_encoder=LearnablePositionEncoder(512, rngs=rngs),
... rngs=rngs
... )
Source code in src/formed/integrations/flax/modules/encoders.py
get_input_dim
¶
get_input_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
get_output_dim
¶
get_output_dim()
Source code in src/formed/integrations/flax/modules/encoders.py
formed.integrations.flax.modules.feedforward
¶
Feed-forward neural network modules for Flax models.
This module provides feed-forward network layers with support for multiple layers, dropout, layer normalization, and residual connections.
Key Components
- Block: Single feed-forward block with optional normalization and dropout
- FeedForward: Stacked feed-forward blocks with configurable connections
Features
- Configurable activation functions
- Layer normalization with custom epsilon
- Dropout for regularization
- Dense residual connections
- Efficient implementation using scan
Examples:
>>> from formed.integrations.flax.modules import FeedForward
>>> from flax import nnx
>>> import jax.nn
>>>
>>> # Simple 3-layer feed-forward network
>>> ffn = FeedForward(
... features=128,
... num_layers=3,
... dropout=0.1,
... activation=jax.nn.gelu,
... rngs=nnx.Rngs(0)
... )
>>>
>>> # With dense residual connections
>>> ffn = FeedForward(
... features=128,
... num_layers=3,
... residual_connection="dense",
... rngs=rngs
... )
Block
¶
Block(
input_dim,
output_dim,
dropout=0.0,
layer_norm_eps=None,
activation=relu,
rngs=None,
)
Bases: Module
Single feed-forward block with optional normalization and dropout.
A block consists of a linear transformation, activation, optional dropout, and optional layer normalization. It can also accept a residual input.
| PARAMETER | DESCRIPTION |
|---|---|
| `input_dim` | Input dimension. |
| `output_dim` | Output dimension. |
| `dropout` | Dropout rate (0 means no dropout). |
| `layer_norm_eps` | Layer normalization epsilon (None means no layer norm). |
| `activation` | Activation function (default: ReLU). |
| `rngs` | Random number generators. |

Source code in src/formed/integrations/flax/modules/feedforward.py
layer_norm
instance-attribute
¶
layer_norm = (
LayerNorm(output_dim, epsilon=layer_norm_eps, rngs=rngs)
if layer_norm_eps is not None
else None
)
FeedForward
¶
FeedForward(
features,
num_layers=1,
dropout=0.0,
layer_norm_eps=None,
activation=relu,
residual_connection="none",
rngs=None,
)
Bases: Module
Multi-layer feed-forward neural network.
This module stacks multiple feed-forward blocks with configurable activation, dropout, normalization, and residual connections.
| PARAMETER | DESCRIPTION |
|---|---|
| `features` | Hidden dimension for all layers. |
| `num_layers` | Number of layers. |
| `dropout` | Dropout rate applied after each activation. |
| `layer_norm_eps` | Epsilon for layer normalization (None disables layer norm). |
| `activation` | Activation function (default: ReLU). |
| `residual_connection` | Type of residual connection: "none" for no residual connections (default), or "dense" for dense connections (each layer receives the sum of all previous). |
| `rngs` | Random number generators (can be int seed or nnx.Rngs). |
Examples:
>>> # Simple 3-layer network
>>> ffn = FeedForward(features=256, num_layers=3, rngs=0)
>>> output = ffn(x)
>>>
>>> # With dropout and layer norm
>>> ffn = FeedForward(
... features=256,
... num_layers=3,
... dropout=0.1,
... layer_norm_eps=1e-6,
... rngs=0
... )
>>>
>>> # With dense residual connections
>>> ffn = FeedForward(
... features=256,
... num_layers=4,
... residual_connection="dense",
... rngs=0
... )
Note
When using residual_connection="dense", each layer receives the sum of outputs from all previous layers, similar to DenseNet.
Source code in src/formed/integrations/flax/modules/feedforward.py
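The DenseNet-style "dense" connection pattern can be sketched in NumPy; the weight list and activation here are hypothetical stand-ins for the module's real parameters, kept only to show the accumulation rule.

```python
import numpy as np

def dense_feedforward(x, weights, activation=lambda v: np.maximum(v, 0.0)):
    # "Dense" residual connections: each layer consumes the running sum of
    # the original input and all previous layer outputs, and its output is
    # added back into that sum.
    accumulated = x
    for w in weights:
        out = activation(accumulated @ w)
        accumulated = accumulated + out
    return accumulated

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8))                        # (batch, features)
weights = [rng.normal(size=(8, 8)) for _ in range(3)]  # one matrix per layer
y = dense_feedforward(x, weights)
print(y.shape)  # (2, 8)
```

Note that this pattern requires all layers to share the same feature dimension, which is why `features` applies to every layer.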
get_input_dim
¶
get_input_dim()
Source code in src/formed/integrations/flax/modules/feedforward.py
get_output_dim
¶
get_output_dim()
Source code in src/formed/integrations/flax/modules/feedforward.py
formed.integrations.flax.modules.losses
¶
BaseClassificationLoss
¶
Bases: Module, Registrable, Generic[_ParamsT], ABC
Abstract base class for classification loss functions.
A ClassificationLoss defines a strategy for computing loss based on model logits and true labels.
| CLASS TYPE PARAMETER | DESCRIPTION |
|---|---|
| `_ParamsT` | Type of additional parameters used during loss computation. |
CrossEntropyLoss
¶
CrossEntropyLoss(weighter=None, reduce='mean')
Bases: BaseClassificationLoss[_ParamsT]
Cross-entropy loss for classification tasks.
| PARAMETER | DESCRIPTION |
|---|---|
| `weighter` | An optional label weighter to assign weights to each class. |
| `reduce` | Reduction method for loss computation ("mean" or "sum"). |

Source code in src/formed/integrations/flax/modules/losses.py
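The computation can be sketched in NumPy: a softmax cross-entropy over integer labels, where a hypothetical `class_weights` array plays the role of the optional label weighter. This is an illustration of the technique, not the module's source.

```python
import numpy as np

def cross_entropy(logits, labels, class_weights=None, reduce="mean"):
    # Numerically stable log-softmax, then negative log-likelihood of the
    # true label; optional per-class weights scale each example's loss.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(len(labels)), labels]
    if class_weights is not None:
        nll = nll * class_weights[labels]
    return nll.mean() if reduce == "mean" else nll.sum()

logits = np.array([[2.0, 0.0], [0.0, 2.0]])
labels = np.array([0, 1])
loss = cross_entropy(logits, labels)  # both examples correctly classified
```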
formed.integrations.flax.modules.samplers
¶
BaseLabelSampler
¶
Bases: Module, Registrable, Generic[_ParamsT], ABC
Abstract base class for label samplers.
A LabelSampler defines a strategy for sampling labels based on model logits.
| CLASS TYPE PARAMETER | DESCRIPTION |
|---|---|
| `_ParamsT` | Type of additional parameters used during sampling. |
ArgmaxLabelSampler
¶
MultinomialLabelSamplerParams
¶
Bases: NamedTuple
Parameters for the MultinomialLabelSampler.
This class can be extended in the future to include additional parameters for sampling if needed.
| ATTRIBUTE | DESCRIPTION |
|---|---|
| `rngs` | Random number generators. |
| `templerature` | Sampling temperature to control randomness. |
MultinomialLabelSampler
¶
Bases: BaseLabelSampler[MultinomialLabelSamplerParams]
Label sampler that samples labels from a multinomial distribution defined by the logits.
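Temperature-scaled multinomial sampling can be sketched in NumPy (the real sampler uses JAX RNGs; this stand-alone illustration shows only the math): logits are divided by the temperature, turned into probabilities via softmax, and a label is drawn per row.

```python
import numpy as np

def sample_multinomial(logits, temperature=1.0, rng=None):
    # Lower temperature sharpens the distribution toward the argmax;
    # higher temperature makes sampling more uniform.
    if rng is None:
        rng = np.random.default_rng(0)
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.array([rng.choice(len(p), p=p) for p in probs])

logits = np.array([[10.0, 0.0, 0.0], [0.0, 0.0, 10.0]])
labels = sample_multinomial(logits, temperature=0.5)
print(labels)  # with these extreme logits, effectively [0, 2]
```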
formed.integrations.flax.modules.vectorizers
¶
Sequence vectorization modules for Flax models.
This module provides vectorizers that convert variable-length sequences into fixed-size vectors. Vectorizers apply pooling operations over the sequence dimension to produce single vectors per sequence.
Key Components
- BaseSequenceVectorizer: Abstract base class for vectorizers
- BagOfEmbeddingsSequenceVectorizer: Pools sequence embeddings
Features
- Multiple pooling strategies (mean, max, min, sum, first, last, hier)
- Masked pooling to ignore padding tokens
- Optional normalization before pooling
- Hierarchical pooling with sliding windows
Examples:
>>> from formed.integrations.flax.modules import BagOfEmbeddingsSequenceVectorizer
>>>
>>> # Mean pooling over sequence
>>> vectorizer = BagOfEmbeddingsSequenceVectorizer(pooling="mean")
>>> vector = vectorizer(embeddings, mask=mask)
>>>
>>> # Max pooling with normalization
>>> vectorizer = BagOfEmbeddingsSequenceVectorizer(
... pooling="max",
... normalize=True
... )
BaseSequenceVectorizer
¶
Bases: Module, Registrable, ABC
Abstract base class for sequence vectorizers.
Vectorizers convert variable-length sequences into fixed-size vectors by applying pooling operations over the sequence dimension.
get_input_dim
abstractmethod
¶
get_input_dim()
Get the expected input dimension.
| RETURNS | DESCRIPTION |
|---|---|
| `int | None` | Input dimension or None if dimension-agnostic. |

Source code in src/formed/integrations/flax/modules/vectorizers.py
get_output_dim
abstractmethod
¶
get_output_dim()
Get the output dimension.
| RETURNS | DESCRIPTION |
|---|---|
int | Callable[[int], int]
|
Output feature dimension or a function mapping input dim to output dim. |
Source code in src/formed/integrations/flax/modules/vectorizers.py
BagOfEmbeddingsSequenceVectorizer
¶
BagOfEmbeddingsSequenceVectorizer(
pooling="mean", normalize=False, window_size=None
)
Bases: BaseSequenceVectorizer
Bag-of-embeddings vectorizer using pooling operations.
This vectorizer applies pooling over the sequence dimension to create fixed-size vectors. Multiple pooling strategies are supported, and padding tokens are properly masked during pooling.
| PARAMETER | DESCRIPTION |
|---|---|
pooling
|
Pooling strategy to use:
- "mean": Average pooling (default)
- "max": Max pooling
- "min": Min pooling
- "sum": Sum pooling
- "first": Take first token
- "last": Take last non-padding token
- "hier": Hierarchical pooling with sliding window
TYPE:
|
normalize
|
Whether to L2-normalize embeddings before pooling.
TYPE:
|
window_size
|
Window size for hierarchical pooling (required if pooling="hier").
TYPE:
|
Examples:
>>> # Mean pooling
>>> vectorizer = BagOfEmbeddingsSequenceVectorizer(pooling="mean")
>>> vector = vectorizer(embeddings, mask=mask)
>>>
>>> # Max pooling with normalization
>>> vectorizer = BagOfEmbeddingsSequenceVectorizer(
... pooling="max",
... normalize=True
... )
>>>
>>> # Hierarchical pooling
>>> vectorizer = BagOfEmbeddingsSequenceVectorizer(
... pooling="hier",
... window_size=3
... )
Note
This vectorizer is dimension-agnostic - it preserves the embedding dimension from input to output.
Source code in src/formed/integrations/flax/modules/vectorizers.py
get_input_dim
¶
get_input_dim()
Source code in src/formed/integrations/flax/modules/vectorizers.py
get_output_dim
¶
get_output_dim()
Source code in src/formed/integrations/flax/modules/vectorizers.py
formed.integrations.flax.modules.weighters
¶
BaseLabelWeighter
¶
Bases: Module, Registrable, Generic[_ParamsT], ABC
Abstract base class for label weighters.
A LabelWeighter defines a strategy for assigning weights to each label based on model logits and true targets.
| CLASS TYPE PARAMETER | DESCRIPTION |
|---|---|
_ParamsT
|
Type of additional parameters used during weighting.
|
StaticLabelWeighter
¶
StaticLabelWeighter(weights)
Bases: BaseLabelWeighter
Label weighter that assigns static weights to each class.
| PARAMETER | DESCRIPTION |
|---|---|
weights
|
An array of shape (num_classes,) containing the weight for each class.
TYPE:
|
Source code in src/formed/integrations/flax/modules/weighters.py
BalancedByDistributionLabelWeighter
¶
BalancedByDistributionLabelWeighter(
distribution, eps=1e-08
)
Bases: BaseLabelWeighter
Label weighter that balances classes based on their distribution.
| PARAMETER | DESCRIPTION |
|---|---|
distribution
|
An array of shape (num_classes,) representing the class distribution.
TYPE:
|
eps
|
A small epsilon value to avoid division by zero.
TYPE:
|
Source code in src/formed/integrations/flax/modules/weighters.py
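As a rough illustration of distribution-based balancing (one common inverse-frequency scheme; the exact formula and normalization used by the library may differ):

```python
import jax.numpy as jnp

def balanced_weights(distribution, eps=1e-8):
    # Inverse-frequency weighting: rare classes receive larger weights.
    # eps guards against division by zero for unobserved classes.
    w = 1.0 / (distribution + eps)
    return w / w.mean()  # normalize so the average weight is 1

weights = balanced_weights(jnp.array([0.7, 0.2, 0.1]))  # rarest class weighted most
```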
formed.integrations.flax.training.callbacks
¶
Training callbacks for monitoring and controlling Flax model training.
This module provides a callback system for Flax training, allowing custom logic to be executed at various points in the training loop. Callbacks can monitor metrics, save checkpoints, implement early stopping, and integrate with experiment tracking systems.
Key Components
- FlaxTrainingCallback: Base class for all callbacks
- EvaluationCallback: Computes metrics using custom evaluators
- EarlyStoppingCallback: Stops training based on metric improvements
- MlflowCallback: Logs metrics to MLflow
Features
- Hook points at training/epoch/batch start and end
- Metric computation and logging
- Model checkpointing
- Early stopping with patience
- MLflow integration
- Extensible for custom callbacks
Examples:
>>> from formed.integrations.flax import (
... FlaxTrainer,
... EarlyStoppingCallback,
... EvaluationCallback,
... MlflowCallback
... )
>>>
>>> trainer = FlaxTrainer(
... train_dataloader=train_loader,
... val_dataloader=val_loader,
... callbacks=[
... EvaluationCallback(my_evaluator),
... EarlyStoppingCallback(patience=5, metric="-loss"),
... MlflowCallback()
... ]
... )
FlaxTrainingCallback
¶
Bases: Registrable
Base class for training callbacks.
Callbacks provide hooks to execute custom logic at various points during training. Subclasses can override any hook method to implement custom behavior such as logging, checkpointing, or early stopping.
Hook execution order
- on_training_start - once at the beginning
- on_epoch_start - at the start of each epoch
- on_batch_start - before each training batch
- on_batch_end - after each training batch
- on_eval_start - before evaluation (returns evaluator)
- on_eval_end - after evaluation with computed metrics
- on_log - when metrics are logged
- on_epoch_end - at the end of each epoch
- on_training_end - once at the end (can modify final state)
Examples:
>>> @FlaxTrainingCallback.register("my_callback")
... class MyCallback(FlaxTrainingCallback):
... def on_epoch_end(self, trainer, model, state, epoch):
... print(f"Completed epoch {epoch} at step {state.step}")
on_training_start
¶
on_training_start(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
on_training_end
¶
on_training_end(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
on_epoch_start
¶
on_epoch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
on_epoch_end
¶
on_epoch_end(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
on_batch_start
¶
on_batch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
on_batch_end
¶
on_batch_end(trainer, model, state, epoch, output)
Source code in src/formed/integrations/flax/training/callbacks.py
on_eval_start
¶
on_eval_start(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
on_eval_end
¶
on_eval_end(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/flax/training/callbacks.py
on_log
¶
on_log(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/flax/training/callbacks.py
EvaluationCallback
¶
EvaluationCallback(evaluator)
Bases: FlaxTrainingCallback, Generic[ModelInputT, ModelOutputT]
Callback for computing metrics using a custom evaluator.
This callback integrates a custom evaluator into the training loop, resetting it before each evaluation phase and returning it for metric accumulation.
| PARAMETER | DESCRIPTION |
|---|---|
evaluator
|
Evaluator implementing the IEvaluator protocol.
TYPE:
|
Examples:
>>> from formed.integrations.ml.metrics import MulticlassAccuracy
>>>
>>> evaluator = MulticlassAccuracy()
>>> callback = EvaluationCallback(evaluator)
Source code in src/formed/integrations/flax/training/callbacks.py
on_eval_start
¶
on_eval_start(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
on_training_start
¶
on_training_start(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
on_training_end
¶
on_training_end(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
on_epoch_start
¶
on_epoch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
on_epoch_end
¶
on_epoch_end(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
on_batch_start
¶
on_batch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
on_batch_end
¶
on_batch_end(trainer, model, state, epoch, output)
Source code in src/formed/integrations/flax/training/callbacks.py
on_eval_end
¶
on_eval_end(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/flax/training/callbacks.py
on_log
¶
on_log(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/flax/training/callbacks.py
EarlyStoppingCallback
¶
EarlyStoppingCallback(patience=5, metric='-loss')
Bases: FlaxTrainingCallback
Callback for early stopping based on metric improvements.
This callback monitors a specified metric and stops training if it doesn't improve for a given number of evaluations (patience). The best model is automatically saved and restored at the end of training.
| PARAMETER | DESCRIPTION |
|---|---|
patience
|
Number of evaluations without improvement before stopping.
TYPE:
|
metric
|
Metric to monitor. Prefix with "-" to minimize (e.g., "-loss"), or "+" to maximize (e.g., "+accuracy"). Default is "-loss".
TYPE:
|
Examples:
>>> # Stop if validation loss doesn't improve for 5 evaluations
>>> callback = EarlyStoppingCallback(patience=5, metric="-val/loss")
>>>
>>> # Stop if accuracy doesn't improve for 3 evaluations
>>> callback = EarlyStoppingCallback(patience=3, metric="+accuracy")
Note
The best model is saved to the step working directory and automatically restored when training ends early or completes.
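The patience bookkeeping can be sketched in plain Python (illustrative only, not the callback's implementation):

```python
def make_early_stopper(patience=5, metric="-loss"):
    # "-" prefix: lower is better (minimize); "+" prefix: higher is better (maximize).
    sign = -1.0 if metric.startswith("-") else 1.0
    name = metric[1:]
    state = {"best": float("-inf"), "bad_evals": 0}

    def update(metrics):
        score = sign * metrics[name]
        if score > state["best"]:
            state["best"], state["bad_evals"] = score, 0
        else:
            state["bad_evals"] += 1
        return state["bad_evals"] >= patience  # True => stop training

    return update

stopper = make_early_stopper(patience=2, metric="-loss")
```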
Source code in src/formed/integrations/flax/training/callbacks.py
on_training_start
¶
on_training_start(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
on_eval_end
¶
on_eval_end(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/flax/training/callbacks.py
on_training_end
¶
on_training_end(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
on_epoch_start
¶
on_epoch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
on_epoch_end
¶
on_epoch_end(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
on_batch_start
¶
on_batch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
on_batch_end
¶
on_batch_end(trainer, model, state, epoch, output)
Source code in src/formed/integrations/flax/training/callbacks.py
on_eval_start
¶
on_eval_start(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
on_log
¶
on_log(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/flax/training/callbacks.py
MlflowCallback
¶
MlflowCallback()
Bases: FlaxTrainingCallback
Callback for logging metrics to MLflow.
This callback automatically logs training and validation metrics to MLflow when used within a workflow step that has MLflow tracking enabled.
Examples:
>>> from formed.integrations.flax import FlaxTrainer, MlflowCallback
>>>
>>> trainer = FlaxTrainer(
... train_dataloader=train_loader,
... val_dataloader=val_loader,
... callbacks=[MlflowCallback()]
... )
Note
Requires the formed mlflow integration and must be used within a workflow step with MLflow tracking configured.
Source code in src/formed/integrations/flax/training/callbacks.py
on_training_start
¶
on_training_start(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
on_log
¶
on_log(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/flax/training/callbacks.py
on_epoch_end
¶
on_epoch_end(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
on_training_end
¶
on_training_end(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
on_epoch_start
¶
on_epoch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
on_batch_start
¶
on_batch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/flax/training/callbacks.py
on_batch_end
¶
on_batch_end(trainer, model, state, epoch, output)
Source code in src/formed/integrations/flax/training/callbacks.py
on_eval_start
¶
on_eval_start(trainer, model, state)
Source code in src/formed/integrations/flax/training/callbacks.py
on_eval_end
¶
on_eval_end(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/flax/training/callbacks.py
formed.integrations.flax.training.engine
¶
Training engine abstractions for Flax models.
This module provides the training engine abstraction that defines how models are trained and evaluated. Engines handle loss computation, gradient calculation, and parameter updates.
Key Components
- FlaxTrainingEngine: Abstract base class for training engines
- DefaultFlaxTrainingEngine: Default implementation with automatic differentiation
Features
- Customizable loss functions
- Automatic gradient computation using JAX
- State creation and management
- Separate train and eval steps
- Compatible with FlaxTrainer and distributors
Examples:
>>> from formed.integrations.flax import DefaultFlaxTrainingEngine
>>>
>>> # Create engine with custom loss accessor
>>> engine = DefaultFlaxTrainingEngine(loss="total_loss")
>>>
>>> # Or with custom loss function
>>> def custom_loss(output):
... return output.loss + 0.1 * output.regularization
>>> engine = DefaultFlaxTrainingEngine(loss=custom_loss)
FlaxTrainingEngine
¶
Bases: ABC, Registrable, Generic[ModelInputT, ModelOutputT, ModelParamsT]
Abstract base class for Flax training engines.
A training engine defines how models are trained by implementing state creation, training steps, and evaluation steps. This allows for custom training loops and loss computations.
| CLASS TYPE PARAMETER | DESCRIPTION |
|---|---|
ModelInputT
|
Type of model input.
|
ModelOutputT
|
Type of model output.
|
ModelParamsT
|
Type of additional parameters.
|
create_state
abstractmethod
¶
create_state(rngs, trainer, model)
Create initial training state from model and trainer.
| PARAMETER | DESCRIPTION |
|---|---|
rngs
|
Random number generators.
TYPE:
|
trainer
|
Trainer instance.
TYPE:
|
model
|
Model to train.
TYPE:
|
| RETURNS | DESCRIPTION |
|---|---|
TrainState
|
Initial training state. |
Source code in src/formed/integrations/flax/training/engine.py
train_step
abstractmethod
¶
train_step(inputs, state, trainer)
Execute a single training step.
| PARAMETER | DESCRIPTION |
|---|---|
inputs
|
Batch of training inputs.
TYPE:
|
state
|
Current training state.
TYPE:
|
trainer
|
Trainer instance.
TYPE:
|
| RETURNS | DESCRIPTION |
|---|---|
tuple[TrainState, ModelOutputT]
|
Tuple of (updated_state, model_output). |
Source code in src/formed/integrations/flax/training/engine.py
eval_step
abstractmethod
¶
eval_step(inputs, state, trainer)
Execute a single evaluation step.
| PARAMETER | DESCRIPTION |
|---|---|
inputs
|
Batch of evaluation inputs.
TYPE:
|
state
|
Current training state.
TYPE:
|
trainer
|
Trainer instance.
TYPE:
|
| RETURNS | DESCRIPTION |
|---|---|
ModelOutputT
|
Model output. |
Source code in src/formed/integrations/flax/training/engine.py
DefaultFlaxTrainingEngine
¶
DefaultFlaxTrainingEngine(
loss="loss", optimizer=adamw(0.001)
)
Bases: FlaxTrainingEngine[ModelInputT, ModelOutputT, ModelParamsT]
Default training engine using automatic differentiation.
This engine computes gradients using JAX's automatic differentiation and updates parameters using the provided optimizer. Loss is extracted from model output either by attribute name or custom function.
| PARAMETER | DESCRIPTION |
|---|---|
loss
|
Loss accessor - either attribute name (e.g., "loss") or callable that extracts loss from model output.
TYPE:
|
optimizer
|
Optax optimizer or transformation.
TYPE:
|
Examples:
>>> # Use output.loss attribute
>>> engine = DefaultFlaxTrainingEngine(loss="loss")
>>>
>>> # Use custom loss function
>>> engine = DefaultFlaxTrainingEngine(
... loss=lambda output: output.loss + 0.01 * output.regularization
... )
Source code in src/formed/integrations/flax/training/engine.py
create_state
¶
create_state(rngs, trainer, model)
Source code in src/formed/integrations/flax/training/engine.py
train_step
¶
train_step(inputs, state, trainer)
Source code in src/formed/integrations/flax/training/engine.py
eval_step
¶
eval_step(inputs, state, trainer)
Source code in src/formed/integrations/flax/training/engine.py
formed.integrations.flax.training.exceptions
¶
StopEarly
¶
Bases: Exception
Raised to stop training early.
formed.integrations.flax.training.state
¶
Training state management for Flax NNX models.
This module extends Flax's TrainState to support NNX models by storing the graph definition and additional states (BatchStat, RngState) separately from trainable parameters.
TrainState
¶
Bases: TrainState, Generic[Node]
Extended training state for Flax NNX models.
This class extends flax.training.train_state.TrainState to work with Flax NNX models, storing the model's graph definition and additional states (like batch statistics and RNG states) separately from parameters.
| ATTRIBUTE | DESCRIPTION |
|---|---|
graphdef |
NNX graph definition describing the model structure.
TYPE:
|
additional_states |
Tuple of additional states (BatchStat, RngState, etc.).
TYPE:
|
params |
Trainable parameters (inherited from TrainState).
TYPE:
|
opt_state |
Optimizer state (inherited from TrainState).
TYPE:
|
step |
Training step counter (inherited from TrainState).
TYPE:
|
tx |
Optimizer transformation (inherited from TrainState).
TYPE:
|
Examples:
>>> # Create state from model
>>> graphdef, params, *states = nnx.split(model, nnx.Param, nnx.BatchStat)
>>> state = TrainState.create(
... apply_fn=None,
... graphdef=graphdef,
... additional_states=tuple(states),
... params=params,
... tx=optimizer
... )
>>>
>>> # Reconstruct model
>>> model = nnx.merge(state.graphdef, state.params, *state.additional_states)
formed.integrations.flax.training.trainer
¶
High-level trainer for Flax models.
This module provides the FlaxTrainer class, which orchestrates the complete training process for Flax models including data loading, optimization, evaluation, callbacks, and distributed training.
Key Features
- Flexible training loop with epoch and step-based logging/evaluation
- Support for callbacks at various training stages
- Distributed training via data parallelism
- Rich progress bars with training metrics
- Early stopping and checkpointing
- MLflow integration
Examples:
>>> from formed.integrations.flax import (
... FlaxTrainer,
... EvaluationCallback,
... EarlyStoppingCallback
... )
>>> from formed.integrations.ml import DataLoader, BasicBatchSampler
>>> import optax
>>>
>>> # Setup data loaders
>>> train_dataloader = DataLoader(
... sampler=BasicBatchSampler(batch_size=32, shuffle=True),
... collator=datamodule.batch
... )
>>>
>>> # Create trainer
>>> trainer = FlaxTrainer(
... train_dataloader=train_dataloader,
... val_dataloader=val_dataloader,
... optimizer=optax.adamw(learning_rate=1e-3),
... max_epochs=10,
... callbacks=[
... EvaluationCallback(my_evaluator),
... EarlyStoppingCallback(patience=3)
... ]
... )
>>>
>>> # Train model
>>> rngs = nnx.Rngs(42)
>>> state = trainer.train(rngs, model, train_dataset, val_dataset)
FlaxTrainer
¶
FlaxTrainer(
*,
train_dataloader,
val_dataloader=None,
engine=None,
callbacks=(),
distributor=None,
max_epochs=10,
eval_strategy="epoch",
eval_interval=1,
logging_strategy="epoch",
logging_interval=1,
logging_first_step=True,
train_prefix="train/",
val_prefix="val/",
)
Bases: Generic[ItemT, ModelInputT, ModelOutputT, ModelParamsT]
High-level trainer for Flax models.
FlaxTrainer provides a complete training loop with support for distributed training, callbacks, evaluation, and metric logging. It handles the coordination of data loading, model training, evaluation, and callback execution.
| CLASS TYPE PARAMETER | DESCRIPTION |
|---|---|
ItemT
|
Type of raw dataset items.
|
ModelInputT
|
Type of batched model inputs.
|
ModelOutputT
|
Type of model outputs.
|
ModelParamsT
|
Type of additional model parameters.
|
| PARAMETER | DESCRIPTION |
|---|---|
train_dataloader
|
Data loader for training dataset.
TYPE:
|
val_dataloader
|
Optional data loader for validation dataset.
TYPE:
|
engine
|
Training engine (defaults to DefaultFlaxTrainingEngine).
TYPE:
|
callbacks
|
Sequence of training callbacks.
TYPE:
|
distributor
|
Device distributor (defaults to SingleDeviceDistributor).
TYPE:
|
max_epochs
|
Maximum number of training epochs.
TYPE:
|
eval_strategy
|
When to evaluate - "epoch" or "step".
TYPE:
|
eval_interval
|
Evaluation interval (epochs or steps).
TYPE:
|
logging_strategy
|
When to log - "epoch" or "step".
TYPE:
|
logging_interval
|
Logging interval (epochs or steps).
TYPE:
|
logging_first_step
|
Whether to log after the first training step.
TYPE:
|
Examples:
>>> trainer = FlaxTrainer(
... train_dataloader=train_loader,
... val_dataloader=val_loader,
... max_epochs=10,
... eval_strategy="epoch",
... logging_strategy="step",
... logging_interval=100,
... engine=DefaultFlaxTrainingEngine(
... optimizer=optax.adamw(1e-3),
... ),
... )
Source code in src/formed/integrations/flax/training/trainer.py
train
¶
train(
model,
train_dataset,
val_dataset=None,
state=None,
rngs=None,
)
Train a model on the provided datasets.
| PARAMETER | DESCRIPTION |
|---|---|
model
|
Model to train.
TYPE:
|
train_dataset
|
Sequence of training items.
TYPE:
|
val_dataset
|
Optional sequence of validation items.
TYPE:
|
state
|
Optional pre-initialized training state (for resuming).
TYPE:
|
rngs
|
Optional random number generators for initialization.
TYPE:
|
| RETURNS | DESCRIPTION |
|---|---|
TrainState
|
Final training state with trained parameters. |
| RAISES | DESCRIPTION |
|---|---|
ValueError
|
If val_dataset is provided but val_dataloader is not. |
Examples:
>>> state = trainer.train(
... model, train_items, val_items
... )
>>> # Reconstruct trained model
>>> trained_model = nnx.merge(
... state.graphdef, state.params, *state.additional_states
... )
Source code in src/formed/integrations/flax/training/trainer.py
formed.integrations.flax.random
¶
use_rngs
¶
use_rngs(default=None)
Context manager to set and restore nnx RNGs.
This context manager allows temporarily setting the nnx RNGs used in Flax/nnx operations. It saves the current RNGs on entry and restores them on exit.
| YIELDS | DESCRIPTION |
|---|---|
Rngs
|
The current nnx RNGs within the context. |
Source code in src/formed/integrations/flax/random.py
require_rngs
¶
require_rngs()
Get the current nnx RNGs.
| RETURNS | DESCRIPTION |
|---|---|
Rngs
|
The current nnx RNGs. |
Source code in src/formed/integrations/flax/random.py
formed.integrations.flax.utils
¶
PoolingMethod
module-attribute
¶
PoolingMethod = Literal[
"mean", "max", "min", "sum", "hier", "first", "last"
]
ensure_jax_array
¶
ensure_jax_array(x)
Source code in src/formed/integrations/flax/utils.py
masked_pool
¶
masked_pool(
embeddings,
mask=None,
pooling="mean",
normalize=False,
window_size=None,
)
Pool embeddings with a mask.
| PARAMETER | DESCRIPTION |
|---|---|
embeddings
|
Embeddings to pool of shape (batch_size, sequence_length, embedding_size).
TYPE:
|
mask
|
Mask of shape (batch_size, sequence_length).
TYPE:
|
pooling
|
Pooling strategy to apply over the sequence dimension. Defaults to "mean".
TYPE:
|
normalize
|
Whether to normalize the embeddings before pooling. Defaults to False.
TYPE:
|
window_size
|
Window size for hierarchical pooling. Defaults to None.
TYPE:
|
Source code in src/formed/integrations/flax/utils.py
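For the default mean strategy, masked pooling is equivalent to the following sketch (a simplification of what masked_pool does; the real function also handles the other strategies):

```python
import jax.numpy as jnp

def masked_mean_pool(embeddings, mask):
    # embeddings: (batch, seq_len, dim); mask: (batch, seq_len), 1 for real tokens.
    m = mask[..., None].astype(embeddings.dtype)
    summed = (embeddings * m).sum(axis=1)
    counts = jnp.clip(m.sum(axis=1), 1e-9, None)  # avoid division by zero
    return summed / counts

x = jnp.array([[[1.0, 1.0], [3.0, 3.0], [100.0, 100.0]]])
mask = jnp.array([[1, 1, 0]])  # third position is padding and is ignored
pooled = masked_mean_pool(x, mask)  # mean of the first two tokens only
```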
sequence_distribute
¶
sequence_distribute(
inputs: Array,
) -> tuple[Array, tuple[int, int]]
sequence_distribute(
inputs: _MappingT, ignore: Sequence[str] = ...
) -> tuple[_MappingT, tuple[int, int]]
sequence_distribute(inputs, ignore=())
Source code in src/formed/integrations/flax/utils.py
sequence_undistribute
¶
sequence_undistribute(
inputs: Array,
shape: tuple[int, int],
ignore: Sequence[str] = ...,
) -> Array
sequence_undistribute(
inputs: _MappingT,
shape: tuple[int, int],
ignore: Sequence[str] = ...,
) -> _MappingT
sequence_undistribute(inputs, shape, ignore=())
Source code in src/formed/integrations/flax/utils.py
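Judging by the signatures, these helpers flatten a leading (batch, sequence) pair of axes into a single axis and later restore it. A plausible sketch of the array case (an assumption about the semantics, not the actual implementation):

```python
import jax.numpy as jnp

def distribute(x):
    # (batch, seq, ...) -> (batch * seq, ...), plus the shape needed to undo it.
    b, s = x.shape[:2]
    return x.reshape((b * s,) + x.shape[2:]), (b, s)

def undistribute(x, shape):
    # (batch * seq, ...) -> (batch, seq, ...)
    b, s = shape
    return x.reshape((b, s) + x.shape[1:])

x = jnp.ones((2, 3, 4))
flat, bs = distribute(x)
```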
determine_ndim
¶
determine_ndim(first, *args)
Source code in src/formed/integrations/flax/utils.py
formed.integrations.flax.workflow
¶
Workflow integration for Flax model training.
This module provides workflow steps for training Flax models, allowing them to be integrated into the formed workflow system with automatic caching and dependency tracking.
Available Steps
- flax::train: Train a Flax model using the provided trainer.
- flax::evaluate: Evaluate a Flax model on a dataset.
Examples:
>>> from formed.integrations.flax import train_flax_model
>>>
>>> # In workflow configuration (jsonnet):
>>> # {
>>> # steps: {
>>> # train: {
>>> # type: "flax::train",
>>> # model: { type: "my_model", ... },
>>> # trainer: { type: "flax_trainer", ... },
>>> # train_dataset: { type: "ref", ref: "preprocess" },
>>> # random_seed: 42
>>> # }
>>> # }
>>> # }
FlaxModelFormat
¶
Bases: Format[BaseFlaxModel]
identifier
property
¶
identifier
Get the unique identifier for this format.
| RETURNS | DESCRIPTION |
|---|---|
str
|
Format identifier string. |
write
¶
write(artifact, directory)
Source code in src/formed/integrations/flax/workflow.py
read
¶
read(directory)
Source code in src/formed/integrations/flax/workflow.py
is_default_of
classmethod
¶
is_default_of(obj)
Check if this format is the default for the given object type.
| PARAMETER | DESCRIPTION |
|---|---|
obj
|
Object to check.
TYPE:
|
| RETURNS | DESCRIPTION |
|---|---|
bool
|
True if this format should be used by default for this type. |
Source code in src/formed/workflow/format.py
train_flax_model
¶
train_flax_model(
model,
trainer,
train_dataset,
val_dataset=None,
random_seed=0,
)
Train a Flax model using the provided trainer.
This workflow step trains a Flax NNX model on the provided datasets, returning the trained model. The training process is cached based on the model architecture, trainer configuration, and dataset fingerprints.
| PARAMETER | DESCRIPTION |
|---|---|
model
|
Flax model to train.
TYPE:
|
trainer
|
Trainer configuration with dataloaders and callbacks.
TYPE:
|
train_dataset
|
Training dataset items.
TYPE:
|
val_dataset
|
Optional validation dataset items.
TYPE:
|
random_seed
|
Random seed for reproducibility.
TYPE:
|
| RETURNS | DESCRIPTION |
|---|---|
BaseFlaxModel
|
Trained Flax model with updated parameters. |
Examples:
>>> # Use in Python code
>>> trained_model = train_flax_model(
... model=my_model,
... trainer=trainer,
... train_dataset=train_data,
... val_dataset=val_data,
... random_seed=42
... )
Source code in src/formed/integrations/flax/workflow.py
evaluate_flax_model
¶
evaluate_flax_model(
model,
evaluator,
dataset,
dataloader,
params=None,
random_seed=None,
)
Evaluate a Flax model on a dataset using the provided evaluator.
| PARAMETER | DESCRIPTION |
|---|---|
model
|
Flax model to evaluate.
TYPE:
|
evaluator
|
Evaluator to compute metrics.
TYPE:
|
dataset
|
Dataset items for evaluation.
TYPE:
|
dataloader
|
DataLoader to convert items to model inputs.
TYPE:
|
params
|
Optional model parameters to use for evaluation.
TYPE:
|
| RETURNS | DESCRIPTION |
|---|---|
dict[str, float]
|
Dictionary of computed evaluation metrics. |
Source code in src/formed/integrations/flax/workflow.py