Skip to content

Torch

formed.integrations.torch.context

Context management for PyTorch operations.

use_device

use_device(device=None)

Context manager to set and restore the default PyTorch device.

This context manager allows temporarily setting the default device used in PyTorch operations (e.g., in ensure_torch_tensor). It saves the current device on entry and restores it on exit.

PARAMETER DESCRIPTION
device

Device to use within the context. Can be a torch.device, a string like "cuda:0" or "cpu", or None.

TYPE: str | device | None DEFAULT: None

YIELDS DESCRIPTION
device

The current device within the context.

Examples:

>>> import torch
>>> from formed.integrations.torch import use_device, ensure_torch_tensor
>>> import numpy as np
>>> with use_device("cuda:0" if torch.cuda.is_available() else "cpu"):
...     arr = np.array([1.0, 2.0, 3.0])
...     tensor = ensure_torch_tensor(arr)
...     print(tensor.device)
cuda:0  # or cpu if CUDA not available
Source code in src/formed/integrations/torch/context.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
@contextmanager
def use_device(device: str | torch.device | None = None) -> Iterator[torch.device]:
    """Context manager to set and restore the default PyTorch device.

    This context manager allows temporarily setting the default device
    used in PyTorch operations (e.g., in `ensure_torch_tensor`). It saves
    the current device on entry and restores it on exit.

    Args:
        device: Device to use within the context. Can be a torch.device,
            a string like `"cuda:0"` or `"cpu"`, or None. When None, the
            best available device is selected (CUDA, then MPS, then CPU).

    Yields:
        The current device within the context.

    Examples:
        >>> import torch
        >>> from formed.integrations.torch import use_device, ensure_torch_tensor
        >>> import numpy as np
        >>> with use_device("cuda:0" if torch.cuda.is_available() else "cpu"):
        ...     arr = np.array([1.0, 2.0, 3.0])
        ...     tensor = ensure_torch_tensor(arr)
        ...     print(tensor.device)
        cuda:0  # or cpu if CUDA not available
    """
    # Pick a sensible default when no device is given: prefer CUDA,
    # then Apple MPS, then CPU.
    if device is None:
        if torch.cuda.is_available():
            device = "cuda:0"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
    # Normalize exactly once (torch.device accepts both str and device);
    # the original converted strings and then re-wrapped the result in
    # torch.device a second time when storing it in the context var.
    device = torch.device(device)
    token = _TORCH_DEVICE.set(device)
    try:
        yield device
    finally:
        # Restore the previously active device even if the body raises.
        _TORCH_DEVICE.reset(token)

get_device

get_device()

Get the current default PyTorch device from context.

RETURNS DESCRIPTION
device | None

The current device set in the context, or None if not set.

Examples:

>>> from formed.integrations.torch import use_device, get_device
>>> with use_device("cuda:0"):
...     print(get_device())
cuda:0
Source code in src/formed/integrations/torch/context.py
53
54
55
56
57
58
59
60
61
62
63
64
65
def get_device() -> torch.device | None:
    """Return the default PyTorch device currently stored in context.

    Returns:
        The device set by an enclosing `use_device` block, or `None`
        when no device has been set.

    Examples:
        >>> from formed.integrations.torch import use_device, get_device
        >>> with use_device("cuda:0"):
        ...     print(get_device())
        cuda:0
    """
    current = _TORCH_DEVICE.get()
    return current

formed.integrations.torch.dataloader

DataLoader utilities for PyTorch training.

This module provides convenient wrappers for creating PyTorch DataLoaders that work seamlessly with the formed training framework.

Examples:

>>> from formed.integrations.torch import DataLoader
>>>
>>> # Create a simple dataloader
>>> train_loader = DataLoader(
...     batch_size=32,
...     shuffle=True,
...     collate_fn=my_collate_fn
... )
>>>
>>> # Use with trainer
>>> trainer = TorchTrainer(
...     train_dataloader=train_loader,
...     ...
... )

ItemT module-attribute

ItemT = TypeVar('ItemT')

DataLoader

DataLoader(
    batch_size,
    shuffle=False,
    collate_fn=None,
    num_workers=0,
    drop_last=False,
    pin_memory=False,
    **kwargs,
)

Simple DataLoader wrapper for PyTorch training.

This class wraps PyTorch's DataLoader with a simpler interface that works with the formed training framework.

PARAMETER DESCRIPTION
batch_size

Number of samples per batch.

TYPE: int

shuffle

Whether to shuffle the data at every epoch.

TYPE: bool DEFAULT: False

collate_fn

Function to collate samples into batches.

TYPE: Callable[[list[ItemT]], ModelInputT] | None DEFAULT: None

num_workers

Number of subprocesses for data loading.

TYPE: int DEFAULT: 0

drop_last

Whether to drop the last incomplete batch.

TYPE: bool DEFAULT: False

pin_memory

If True, tensors are copied to CUDA pinned memory.

TYPE: bool DEFAULT: False

**kwargs

Additional arguments passed to torch.utils.data.DataLoader.

DEFAULT: {}

Examples:

>>> def collate_fn(batch):
...     # Convert list of samples to batch tensors
...     return {"features": torch.stack([x["features"] for x in batch])}
>>>
>>> loader = DataLoader(
...     batch_size=32,
...     shuffle=True,
...     collate_fn=collate_fn
... )
Source code in src/formed/integrations/torch/dataloader.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def __init__(
    self,
    batch_size: int,
    shuffle: bool = False,
    collate_fn: Callable[[list[ItemT]], ModelInputT] | None = None,
    num_workers: int = 0,
    drop_last: bool = False,
    pin_memory: bool = False,
    **kwargs,
) -> None:
    # Record every configuration option verbatim; the underlying
    # torch.utils.data.DataLoader is built later from these settings.
    options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "collate_fn": collate_fn,
        "num_workers": num_workers,
        "drop_last": drop_last,
        "pin_memory": pin_memory,
        "kwargs": kwargs,
    }
    for attribute, value in options.items():
        setattr(self, attribute, value)

batch_size instance-attribute

batch_size = batch_size

shuffle instance-attribute

shuffle = shuffle

collate_fn instance-attribute

collate_fn = collate_fn

num_workers instance-attribute

num_workers = num_workers

drop_last instance-attribute

drop_last = drop_last

pin_memory instance-attribute

pin_memory = pin_memory

kwargs instance-attribute

kwargs = kwargs

formed.integrations.torch.distributors

Distributed computing abstractions for PyTorch models.

This module provides abstractions for distributed training across multiple devices, supporting both single-device and data-parallel training strategies.

Key Components
  • BaseDistributor: Abstract interface for device distribution strategies
  • SingleDeviceDistributor: No-op distributor for single-device training
  • DataParallelDistributor: Data-parallel training using torch.nn.DataParallel
Features
  • Transparent device sharding and replication
  • Reduction operations (mean, sum) across devices
  • Compatible with TorchTrainer

Examples:

>>> from formed.integrations.torch import DataParallelDistributor
>>> import torch
>>>
>>> # Create data-parallel distributor for all available GPUs
>>> distributor = DataParallelDistributor()
>>>
>>> # Shard batch across devices
>>> sharded_batch = distributor.shard(batch)

BaseDistributor

Bases: Registrable, ABC, Generic[ModelInputT]

Abstract base class for device distribution strategies.

BaseDistributor defines the interface for distributing computations across devices in a PyTorch training pipeline. It provides a unified API for single-device, data-parallel, and distributed data-parallel training.

CLASS TYPE PARAMETER DESCRIPTION
ModelInputT

Type of model input data.

Key Methods
  • device: Primary device for computation
  • is_main_process: Whether this is the main process (for logging, saving, etc.)
  • wrap_model: Wrap model for distributed training
  • prepare_data_loader: Prepare data loader with appropriate sampler
  • reduce: Reduce tensor across devices/processes
  • barrier: Synchronize all processes
  • all_gather: Gather tensors from all processes

device abstractmethod property

device

Primary device for computation.

is_main_process property

is_main_process

Whether this is the main process.

The main process is responsible for: logging to the console, saving models and checkpoints, and writing metrics to file.

RETURNS DESCRIPTION
bool

True if this is the main process (rank 0), False otherwise.

world_size property

world_size

Total number of processes/devices.

RETURNS DESCRIPTION
int

Number of processes in distributed training, or 1 for single device.

rank property

rank

Global rank of this process.

RETURNS DESCRIPTION
int

Rank of this process (0 for main process).

wrap_model

wrap_model(model)

Wrap model for distributed training.

PARAMETER DESCRIPTION
model

Model to wrap.

TYPE: Module

RETURNS DESCRIPTION
Module

Wrapped model (DataParallel, DDP, or unchanged).

Source code in src/formed/integrations/torch/distributors.py
105
106
107
108
109
110
111
112
113
114
115
def wrap_model(self, model: nn.Module) -> nn.Module:
    """Return the model unchanged.

    The base implementation performs no wrapping; subclasses override
    this to wrap the model for DataParallel or DDP training.

    Args:
        model: Model to (potentially) wrap.

    Returns:
        The same module instance, unmodified.

    """
    return model

prepare_data_loader

prepare_data_loader(
    dataset,
    batch_size,
    shuffle=False,
    num_workers=0,
    drop_last=False,
    **kwargs,
)

Prepare data loader with appropriate sampler for this distributor.

For a single device, the default sampler is used. For DataParallel, the default sampler is also used (the data split happens in forward). For DDP, a DistributedSampler splits the data across processes.

PARAMETER DESCRIPTION
dataset

Dataset to load.

TYPE: Sequence

batch_size

Batch size per device/process.

TYPE: int

shuffle

Whether to shuffle data.

TYPE: bool DEFAULT: False

num_workers

Number of worker processes.

TYPE: int DEFAULT: 0

drop_last

Whether to drop last incomplete batch.

TYPE: bool DEFAULT: False

**kwargs

Additional arguments for DataLoader.

DEFAULT: {}

RETURNS DESCRIPTION
DataLoader

Configured DataLoader.

Source code in src/formed/integrations/torch/distributors.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def prepare_data_loader(
    self,
    dataset: Sequence,
    batch_size: int,
    shuffle: bool = False,
    num_workers: int = 0,
    drop_last: bool = False,
    **kwargs,
) -> torch.utils.data.DataLoader:
    """Build a DataLoader with the sampler appropriate for this distributor.

    Single-device and DataParallel training use the default sampler
    (DataParallel splits each batch inside `forward`); the DDP
    distributor overrides this to install a `DistributedSampler`.

    Args:
        dataset: Dataset to load.
        batch_size: Batch size per device/process.
        shuffle: Whether to shuffle data.
        num_workers: Number of worker processes.
        drop_last: Whether to drop last incomplete batch.
        **kwargs: Extra keyword arguments forwarded to `DataLoader`.

    Returns:
        Configured DataLoader.

    """
    from torch.utils.data import DataLoader

    # Collect keyword options first; a duplicate key in **kwargs still
    # raises TypeError, matching a direct DataLoader(...) call.
    options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
        **kwargs,
    )
    return DataLoader(dataset, **options)  # type: ignore[arg-type]

reduce abstractmethod

reduce(tensor, op='mean')

Reduce a tensor across devices/processes.

PARAMETER DESCRIPTION
tensor

Tensor to reduce.

TYPE: _TensorT

op

Reduction operation ("mean" or "sum").

TYPE: _ReduceOp DEFAULT: 'mean'

RETURNS DESCRIPTION
_TensorT

Reduced tensor.

Source code in src/formed/integrations/torch/distributors.py
155
156
157
158
159
160
161
162
163
164
165
166
167
@abc.abstractmethod
def reduce(self, tensor: _TensorT, op: _ReduceOp = "mean") -> _TensorT:
    """Reduce a tensor across devices/processes.

    Abstract: every concrete distributor must implement this with the
    semantics appropriate to its distribution strategy.

    Args:
        tensor: Tensor to reduce.
        op: Reduction operation (`"mean"` or `"sum"`).

    Returns:
        Reduced tensor.

    """
    raise NotImplementedError

barrier

barrier()

Synchronize all processes.

This is a no-op for single device and DataParallel. For DDP, it blocks until all processes reach this point.

Source code in src/formed/integrations/torch/distributors.py
169
170
171
172
173
174
175
176
def barrier(self) -> None:
    """Synchronize all processes.

    Meaningful only for multi-process training: DDP blocks here until
    every process arrives. The base implementation (single device and
    DataParallel) has nothing to synchronize, so it does nothing.

    """
    return None

all_gather

all_gather(tensor)

Gather tensors from all processes.

PARAMETER DESCRIPTION
tensor

Tensor to gather.

TYPE: Tensor

RETURNS DESCRIPTION
list[Tensor]

List of tensors from all processes.

list[Tensor]

For single device/DataParallel, returns [tensor].

Source code in src/formed/integrations/torch/distributors.py
178
179
180
181
182
183
184
185
186
187
188
189
def all_gather(self, tensor: torch.Tensor) -> list[torch.Tensor]:
    """Collect the given tensor from every process.

    In the base (single device / DataParallel) case there is exactly
    one process, so the result is a one-element list containing the
    input tensor itself.

    Args:
        tensor: Tensor to gather.

    Returns:
        List of tensors from all processes; here, `[tensor]`.

    """
    gathered = [tensor]
    return gathered

cleanup

cleanup()

Cleanup resources (e.g., distributed process group).

This is a no-op for single device and DataParallel. For DDP, destroys the process group.

Source code in src/formed/integrations/torch/distributors.py
191
192
193
194
195
196
197
198
def cleanup(self) -> None:
    """Release any distribution-related resources.

    Only DDP holds real resources (the process group it destroys on
    cleanup); the base implementation for single device and
    DataParallel is intentionally a no-op.

    """
    return None

SingleDeviceDistributor

SingleDeviceDistributor(device=None)

Bases: BaseDistributor[ModelInputT]

Distributor for single-device training.

This distributor operates on a single device without any distribution. All shard, replicate, and unreplicate operations are no-ops.

PARAMETER DESCRIPTION
device

Device to use (default: "cuda" if available, else "cpu").

TYPE: Optional[Union[str, device]] DEFAULT: None

Examples:

>>> distributor = SingleDeviceDistributor(device="cuda:0")
>>> model = model.to(distributor.device)
Source code in src/formed/integrations/torch/distributors.py
217
218
219
220
def __init__(self, device: Optional[Union[str, torch.device]] = None) -> None:
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    self._device = torch.device(device)

device property

device

is_main_process property

is_main_process

Whether this is the main process.

The main process is responsible for: - Logging to console - Saving models and checkpoints - Writing metrics to file

RETURNS DESCRIPTION
bool

True if this is the main process (rank 0), False otherwise.

world_size property

world_size

Total number of processes/devices.

RETURNS DESCRIPTION
int

Number of processes in distributed training, or 1 for single device.

rank property

rank

Global rank of this process.

RETURNS DESCRIPTION
int

Rank of this process (0 for main process).

reduce

reduce(tensor, op='mean')

Return tensor unchanged (no reduction needed for single device).

PARAMETER DESCRIPTION
tensor

Input tensor.

TYPE: _TensorT

op

Reduction operation (ignored).

TYPE: _ReduceOp DEFAULT: 'mean'

RETURNS DESCRIPTION
_TensorT

Input tensor unchanged.

Source code in src/formed/integrations/torch/distributors.py
226
227
228
229
230
231
232
233
234
235
236
237
def reduce(self, tensor: _TensorT, op: _ReduceOp = "mean") -> _TensorT:
    """Pass the tensor through untouched.

    With a single device there is nothing to reduce across, so the
    input is returned as-is and `op` is ignored.

    Args:
        tensor: Input tensor.
        op: Reduction operation (ignored).

    Returns:
        The input tensor, unchanged.

    """
    return tensor

wrap_model

wrap_model(model)

Wrap model for distributed training.

PARAMETER DESCRIPTION
model

Model to wrap.

TYPE: Module

RETURNS DESCRIPTION
Module

Wrapped model (DataParallel, DDP, or unchanged).

Source code in src/formed/integrations/torch/distributors.py
105
106
107
108
109
110
111
112
113
114
115
def wrap_model(self, model: nn.Module) -> nn.Module:
    """Return the model unchanged.

    Single-device training needs no wrapping, so this is the inherited
    identity behavior: the input module is handed straight back.

    Args:
        model: Model to (potentially) wrap.

    Returns:
        The same module instance, unmodified.

    """
    return model

prepare_data_loader

prepare_data_loader(
    dataset,
    batch_size,
    shuffle=False,
    num_workers=0,
    drop_last=False,
    **kwargs,
)

Prepare data loader with appropriate sampler for this distributor.

For a single device, the default sampler is used. For DataParallel, the default sampler is also used (the data split happens in forward). For DDP, a DistributedSampler splits the data across processes.

PARAMETER DESCRIPTION
dataset

Dataset to load.

TYPE: Sequence

batch_size

Batch size per device/process.

TYPE: int

shuffle

Whether to shuffle data.

TYPE: bool DEFAULT: False

num_workers

Number of worker processes.

TYPE: int DEFAULT: 0

drop_last

Whether to drop last incomplete batch.

TYPE: bool DEFAULT: False

**kwargs

Additional arguments for DataLoader.

DEFAULT: {}

RETURNS DESCRIPTION
DataLoader

Configured DataLoader.

Source code in src/formed/integrations/torch/distributors.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def prepare_data_loader(
    self,
    dataset: Sequence,
    batch_size: int,
    shuffle: bool = False,
    num_workers: int = 0,
    drop_last: bool = False,
    **kwargs,
) -> torch.utils.data.DataLoader:
    """Build a DataLoader using PyTorch's default sampler.

    On a single device no special sampling is required, so the dataset
    is handed directly to `torch.utils.data.DataLoader`.

    Args:
        dataset: Dataset to load.
        batch_size: Batch size per device/process.
        shuffle: Whether to shuffle data.
        num_workers: Number of worker processes.
        drop_last: Whether to drop last incomplete batch.
        **kwargs: Extra keyword arguments forwarded to `DataLoader`.

    Returns:
        Configured DataLoader.

    """
    from torch.utils.data import DataLoader

    # Collect keyword options first; a duplicate key in **kwargs still
    # raises TypeError, matching a direct DataLoader(...) call.
    options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
        **kwargs,
    )
    return DataLoader(dataset, **options)  # type: ignore[arg-type]

barrier

barrier()

Synchronize all processes.

This is a no-op for single device and DataParallel. For DDP, it blocks until all processes reach this point.

Source code in src/formed/integrations/torch/distributors.py
169
170
171
172
173
174
175
176
def barrier(self) -> None:
    """Synchronize all processes (no-op here).

    There is only one process in single-device training, so there is
    nothing to wait for; DDP provides the real synchronization.

    """
    return None

all_gather

all_gather(tensor)

Gather tensors from all processes.

PARAMETER DESCRIPTION
tensor

Tensor to gather.

TYPE: Tensor

RETURNS DESCRIPTION
list[Tensor]

List of tensors from all processes.

list[Tensor]

For single device/DataParallel, returns [tensor].

Source code in src/formed/integrations/torch/distributors.py
178
179
180
181
182
183
184
185
186
187
188
189
def all_gather(self, tensor: torch.Tensor) -> list[torch.Tensor]:
    """Collect the given tensor from every process.

    Single-device training has exactly one process, so the gathered
    result is simply a one-element list holding the input tensor.

    Args:
        tensor: Tensor to gather.

    Returns:
        List of tensors from all processes; here, `[tensor]`.

    """
    gathered = [tensor]
    return gathered

cleanup

cleanup()

Cleanup resources (e.g., distributed process group).

This is a no-op for single device and DataParallel. For DDP, destroys the process group.

Source code in src/formed/integrations/torch/distributors.py
191
192
193
194
195
196
197
198
def cleanup(self) -> None:
    """Release any distribution-related resources (no-op here).

    Single-device training holds no process group or other shared
    resource, so there is nothing to tear down.

    """
    return None

DataParallelDistributor

DataParallelDistributor(
    device_ids=None, output_device=None
)

Bases: BaseDistributor[ModelInputT]

Distributor for data-parallel training across multiple GPUs.

This distributor uses torch.nn.DataParallel to execute the same computation on different data shards across multiple GPUs. Data is automatically sharded along the batch dimension.

PARAMETER DESCRIPTION
device_ids

List of GPU device IDs to use. Defaults to all available GPUs.

TYPE: Optional[list[int]] DEFAULT: None

output_device

Device for outputs. Defaults to device_ids[0].

TYPE: Optional[int] DEFAULT: None

Examples:

>>> # Train on GPUs 0 and 1 with data parallelism
>>> distributor = DataParallelDistributor(device_ids=[0, 1])
>>>
>>> # Wrap model for data parallel training
>>> model = distributor.wrap_model(model)
Note

Batch size must be divisible by the number of devices for proper sharding.

Source code in src/formed/integrations/torch/distributors.py
264
265
266
267
268
269
270
271
272
273
274
275
276
def __init__(
    self,
    device_ids: Optional[list[int]] = None,
    output_device: Optional[int] = None,
) -> None:
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    if not device_ids:
        raise ValueError("No GPU devices available for DataParallelDistributor")

    self._device_ids = device_ids
    self._output_device = output_device if output_device is not None else device_ids[0]
    self._device = torch.device(f"cuda:{self._output_device}")

device property

device

is_main_process property

is_main_process

Whether this is the main process.

The main process is responsible for: - Logging to console - Saving models and checkpoints - Writing metrics to file

RETURNS DESCRIPTION
bool

True if this is the main process (rank 0), False otherwise.

world_size property

world_size

Total number of processes/devices.

RETURNS DESCRIPTION
int

Number of processes in distributed training, or 1 for single device.

rank property

rank

Global rank of this process.

RETURNS DESCRIPTION
int

Rank of this process (0 for main process).

wrap_model

wrap_model(model)

Wrap model with DataParallel.

PARAMETER DESCRIPTION
model

Model to wrap.

TYPE: Module

RETURNS DESCRIPTION
Module

DataParallel wrapped model.

Source code in src/formed/integrations/torch/distributors.py
282
283
284
285
286
287
288
289
290
291
292
def wrap_model(self, model: nn.Module) -> nn.Module:
    """Wrap model with `DataParallel`.

    Args:
        model: Model to wrap.

    Returns:
        `DataParallel` wrapped model.

    """
    wrapped = nn.DataParallel(
        model,
        device_ids=self._device_ids,
        output_device=self._output_device,
    )
    return cast(nn.Module, wrapped)

reduce

reduce(tensor, op='mean')

Reduce tensor across devices.

PARAMETER DESCRIPTION
tensor

Tensor to reduce across device dimension.

TYPE: _TensorT

op

Reduction operation - "sum" or "mean".

TYPE: _ReduceOp DEFAULT: 'mean'

RETURNS DESCRIPTION
_TensorT

Reduced tensor.

RAISES DESCRIPTION
ValueError

If unsupported reduction operation is specified.

Source code in src/formed/integrations/torch/distributors.py
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
def reduce(self, tensor: _TensorT, op: _ReduceOp = "mean") -> _TensorT:
    """Collapse a tensor that is spread across the device dimension.

    Args:
        tensor: Tensor to reduce across the device dimension.
        op: Reduction operation - `"sum"` or `"mean"`.

    Returns:
        Reduced tensor.

    Raises:
        ValueError: If unsupported reduction operation is specified.

    """
    # Reject unknown operations up front, then dispatch.
    if op not in ("sum", "mean"):
        raise ValueError(f"Unsupported reduce operation: {op}")
    result = tensor.sum() if op == "sum" else tensor.mean()
    return cast(_TensorT, result)

prepare_data_loader

prepare_data_loader(
    dataset,
    batch_size,
    shuffle=False,
    num_workers=0,
    drop_last=False,
    **kwargs,
)

Prepare data loader with appropriate sampler for this distributor.

For a single device, the default sampler is used. For DataParallel, the default sampler is also used (the data split happens in forward). For DDP, a DistributedSampler splits the data across processes.

PARAMETER DESCRIPTION
dataset

Dataset to load.

TYPE: Sequence

batch_size

Batch size per device/process.

TYPE: int

shuffle

Whether to shuffle data.

TYPE: bool DEFAULT: False

num_workers

Number of worker processes.

TYPE: int DEFAULT: 0

drop_last

Whether to drop last incomplete batch.

TYPE: bool DEFAULT: False

**kwargs

Additional arguments for DataLoader.

DEFAULT: {}

RETURNS DESCRIPTION
DataLoader

Configured DataLoader.

Source code in src/formed/integrations/torch/distributors.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def prepare_data_loader(
    self,
    dataset: Sequence,
    batch_size: int,
    shuffle: bool = False,
    num_workers: int = 0,
    drop_last: bool = False,
    **kwargs,
) -> torch.utils.data.DataLoader:
    """Prepare data loader with appropriate sampler for this distributor.

    For single device: uses default sampler
    For DataParallel: uses default sampler (data split happens in forward)
    For DDP: uses DistributedSampler to split data across processes

    Args:
        dataset: Dataset to load.
        batch_size: Batch size per device/process.
        shuffle: Whether to shuffle data.
        num_workers: Number of worker processes.
        drop_last: Whether to drop last incomplete batch.
        **kwargs: Additional arguments for DataLoader.

    Returns:
        Configured DataLoader.

    """
    from torch.utils.data import DataLoader

    return DataLoader(
        dataset,  # type: ignore[arg-type]
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
        **kwargs,
    )

barrier

barrier()

Synchronize all processes.

This is a no-op for single device and DataParallel. For DDP, it blocks until all processes reach this point.

Source code in src/formed/integrations/torch/distributors.py
169
170
171
172
173
174
175
176
def barrier(self) -> None:
    """Synchronize all processes.

    This is a no-op for single device and DataParallel.
    For DDP, it blocks until all processes reach this point.

    """
    pass

all_gather

all_gather(tensor)

Gather tensors from all processes.

PARAMETER DESCRIPTION
tensor

Tensor to gather.

TYPE: Tensor

RETURNS DESCRIPTION
list[Tensor]

List of tensors from all processes.

list[Tensor]

For single device/DataParallel, returns [tensor].

Source code in src/formed/integrations/torch/distributors.py
178
179
180
181
182
183
184
185
186
187
188
189
def all_gather(self, tensor: torch.Tensor) -> list[torch.Tensor]:
    """Gather tensors from all processes.

    Args:
        tensor: Tensor to gather.

    Returns:
        List of tensors from all processes.
        For single device/DataParallel, returns [tensor].

    """
    return [tensor]

cleanup

cleanup()

Cleanup resources (e.g., distributed process group).

This is a no-op for single device and DataParallel. For DDP, destroys the process group.

Source code in src/formed/integrations/torch/distributors.py
191
192
193
194
195
196
197
198
def cleanup(self) -> None:
    """Cleanup resources (e.g., distributed process group).

    This is a no-op for single device and DataParallel.
    For DDP, destroys the process group.

    """
    pass

DistributedDataParallelDistributor

DistributedDataParallelDistributor(
    backend=None,
    init_method="env://",
    world_size=None,
    rank=None,
    local_rank=None,
    find_unused_parameters=False,
    broadcast_buffers=True,
    bucket_cap_mb=25,
)

Bases: BaseDistributor[ModelInputT]

Distributor for distributed data-parallel training using DDP.

This distributor uses torch.nn.parallel.DistributedDataParallel to execute training across multiple processes and devices. This is more efficient than DataParallel for multi-GPU training as it uses one process per GPU.

PARAMETER DESCRIPTION
backend

Backend to use for distributed training ("nccl", "gloo", "mpi"). Defaults to "nccl" for GPU and "gloo" for CPU.

TYPE: Optional[str] DEFAULT: None

init_method

URL specifying how to initialize the process group. Defaults to "env://" which uses environment variables.

TYPE: str DEFAULT: 'env://'

world_size

Total number of processes. If None, reads from environment.

TYPE: Optional[int] DEFAULT: None

rank

Rank of this process. If None, reads from environment.

TYPE: Optional[int] DEFAULT: None

local_rank

Local rank on this machine. If None, uses rank.

TYPE: Optional[int] DEFAULT: None

find_unused_parameters

Whether to find unused parameters. Default False.

TYPE: bool DEFAULT: False

broadcast_buffers

Whether to broadcast buffers. Default True.

TYPE: bool DEFAULT: True

bucket_cap_mb

Bucket size in MB for gradient allreduce. Default 25.

TYPE: int DEFAULT: 25

Environment Variables
  • RANK: Global rank of the process
  • LOCAL_RANK: Local rank on the machine
  • WORLD_SIZE: Total number of processes
  • MASTER_ADDR: Address of the master node
  • MASTER_PORT: Port of the master node

Examples:

>>> # On each process, initialize the distributor
>>> distributor = DistributedDataParallelDistributor(
...     backend="nccl",
...     init_method="env://"
... )
>>>
>>> # Wrap model with DDP
>>> model = distributor.wrap_model(model)
>>>
>>> # Train as usual - gradients are automatically synchronized
Note
  • Requires launching multiple processes (e.g., using torch.distributed.launch)
  • Each process should initialize its own distributor
  • Batch size should be per-process batch size
Source code in src/formed/integrations/torch/distributors.py
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
def __init__(
    self,
    backend: Optional[str] = None,
    init_method: str = "env://",
    world_size: Optional[int] = None,
    rank: Optional[int] = None,
    local_rank: Optional[int] = None,
    find_unused_parameters: bool = False,
    broadcast_buffers: bool = True,
    bucket_cap_mb: int = 25,
) -> None:
    """Initialize the DDP distributor and join the process group.

    Side effects: may call `torch.distributed.init_process_group` (once
    per process) and `torch.cuda.set_device`. Rank/world-size fall back
    to the `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` environment variables
    when not passed explicitly.

    Args:
        backend: Distributed backend; defaults to "nccl" when CUDA is
            available, otherwise "gloo".
        init_method: URL for process-group initialization (default
            "env://", i.e. environment variables).
        world_size: Total number of processes; read from env if None.
        rank: Global rank of this process; read from env if None.
        local_rank: Local rank on this machine; defaults to `rank`'s
            env-derived value if None.
        find_unused_parameters: Forwarded to DDP at wrap time.
        broadcast_buffers: Forwarded to DDP at wrap time.
        bucket_cap_mb: Gradient-allreduce bucket size, forwarded to DDP.

    Raises:
        RuntimeError: If the process group cannot be initialized.
    """
    import os

    import torch.distributed as dist

    # Backend default: NCCL is the fast path for CUDA; gloo works on CPU.
    if backend is None:
        backend = "nccl" if torch.cuda.is_available() else "gloo"

    # Fall back to the torchrun-provided environment variables; the
    # defaults (rank 0, world size 1) make single-process runs work.
    if rank is None:
        rank = int(os.environ.get("RANK", 0))
    if world_size is None:
        world_size = int(os.environ.get("WORLD_SIZE", 1))
    if local_rank is None:
        local_rank = int(os.environ.get("LOCAL_RANK", rank))

    self._backend = backend
    self._init_method = init_method
    self._world_size = world_size
    self._rank = rank
    self._local_rank = local_rank
    self._find_unused_parameters = find_unused_parameters
    self._broadcast_buffers = broadcast_buffers
    self._bucket_cap_mb = bucket_cap_mb

    # Join the process group only if no one else has already done so
    # (e.g. a framework launcher may initialize it before us).
    if not dist.is_initialized():
        try:
            dist.init_process_group(
                backend=backend,
                init_method=init_method,
                world_size=world_size,
                rank=rank,
            )
        except Exception as e:
            raise RuntimeError(
                f"Failed to initialize distributed process group. "
                f"Backend: {backend}, init_method: {init_method}, "
                f"world_size: {world_size}, rank: {rank}. "
                f"Error: {e}. "
                "Please ensure all processes are launched correctly using torchrun "
                "and required environment variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) are set."
            ) from e

    # Each process owns one GPU, selected by its local rank; CPU-only
    # hosts simply run everything on "cpu".
    if torch.cuda.is_available():
        self._device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(self._device)
    else:
        self._device = torch.device("cpu")

device property

device

is_main_process property

is_main_process

Whether this is the main process (rank 0).

rank property

rank

Global rank of this process.

local_rank property

local_rank

Local rank on this machine.

world_size property

world_size

Total number of processes.

wrap_model

wrap_model(model)

Wrap model with DistributedDataParallel.

PARAMETER DESCRIPTION
model

Model to wrap.

TYPE: Module

RETURNS DESCRIPTION
Module

DDP wrapped model.

Source code in src/formed/integrations/torch/distributors.py
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
def wrap_model(self, model: nn.Module) -> nn.Module:
    """Wrap a model in `DistributedDataParallel` using this distributor's settings.

    Args:
        model: Model to wrap.

    Returns:
        The DDP-wrapped model.

    """
    on_gpu = torch.cuda.is_available()
    wrapped = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[self._local_rank] if on_gpu else None,
        output_device=self._local_rank if on_gpu else None,
        find_unused_parameters=self._find_unused_parameters,
        broadcast_buffers=self._broadcast_buffers,
        bucket_cap_mb=self._bucket_cap_mb,
    )
    return cast(nn.Module, wrapped)

prepare_data_loader

prepare_data_loader(
    dataset,
    batch_size,
    shuffle=False,
    num_workers=0,
    drop_last=False,
    **kwargs,
)

Prepare data loader with DistributedSampler for DDP.

PARAMETER DESCRIPTION
dataset

Dataset to load.

TYPE: Sequence

batch_size

Batch size per process.

TYPE: int

shuffle

Whether to shuffle data.

TYPE: bool DEFAULT: False

num_workers

Number of worker processes.

TYPE: int DEFAULT: 0

drop_last

Whether to drop last incomplete batch.

TYPE: bool DEFAULT: False

**kwargs

Additional arguments for DataLoader.

DEFAULT: {}

RETURNS DESCRIPTION
DataLoader

DataLoader with DistributedSampler.

Source code in src/formed/integrations/torch/distributors.py
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
def prepare_data_loader(
    self,
    dataset: Sequence,
    batch_size: int,
    shuffle: bool = False,
    num_workers: int = 0,
    drop_last: bool = False,
    **kwargs,
) -> torch.utils.data.DataLoader:
    """Prepare data loader with DistributedSampler for DDP.

    Args:
        dataset: Dataset to load.
        batch_size: Batch size per process.
        shuffle: Whether to shuffle data (handled by the sampler).
        num_workers: Number of worker processes.
        drop_last: Whether to drop last incomplete batch. Also passed to the
            sampler so the dataset tail that does not divide evenly across
            replicas is dropped.
        **kwargs: Additional arguments for DataLoader.

    Returns:
        DataLoader with DistributedSampler.

    """
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    sampler = DistributedSampler(
        dataset,  # type: ignore[arg-type]
        num_replicas=self._world_size,
        rank=self._rank,
        shuffle=shuffle,
        drop_last=drop_last,
    )

    # BUGFIX: drop_last was previously only given to the sampler, so the
    # documented "drop last incomplete batch" behavior never happened —
    # DataLoader must receive it as well to drop a trailing partial batch.
    return DataLoader(
        dataset,  # type: ignore[arg-type]
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        drop_last=drop_last,
        **kwargs,
    )

reduce

reduce(tensor, op='mean')

Reduce tensor across all processes.

PARAMETER DESCRIPTION
tensor

Tensor to reduce.

TYPE: _TensorT

op

Reduction operation - "sum" or "mean".

TYPE: _ReduceOp DEFAULT: 'mean'

RETURNS DESCRIPTION
_TensorT

Reduced tensor.

RAISES DESCRIPTION
ValueError

If unsupported reduction operation is specified.

Source code in src/formed/integrations/torch/distributors.py
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
def reduce(self, tensor: _TensorT, op: _ReduceOp = "mean") -> _TensorT:
    """Combine a tensor across all processes with the given reduction.

    Args:
        tensor: Tensor to reduce (reduced in place by `all_reduce`).
        op: Reduction operation - `"sum"` or `"mean"`.

    Returns:
        The reduced tensor.

    Raises:
        ValueError: If unsupported reduction operation is specified.

    """
    import torch.distributed as dist

    # Validate before touching the process group.
    if op not in ("sum", "mean"):
        raise ValueError(f"Unsupported reduce operation: {op}")

    # Both supported ops start from an all-process sum.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    if op == "mean":
        return cast(_TensorT, tensor / self._world_size)
    return tensor

all_gather

all_gather(tensor)

Gather tensors from all processes.

PARAMETER DESCRIPTION
tensor

Tensor to gather.

TYPE: Tensor

RETURNS DESCRIPTION
list[Tensor]

List of tensors from all processes.

Source code in src/formed/integrations/torch/distributors.py
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
def all_gather(self, tensor: torch.Tensor) -> list[torch.Tensor]:
    """Collect this tensor from every process in the group.

    Args:
        tensor: Local tensor contributed by this process.

    Returns:
        Tensors from all processes, ordered by rank.

    """
    import torch.distributed as dist

    # One receive buffer per rank; each must be a distinct tensor.
    buffers: list[torch.Tensor] = []
    for _ in range(self._world_size):
        buffers.append(torch.zeros_like(tensor))
    dist.all_gather(buffers, tensor)
    return buffers

barrier

barrier()

Synchronize all processes.

This creates a barrier that blocks until all processes reach this point.

Source code in src/formed/integrations/torch/distributors.py
551
552
553
554
555
556
557
558
559
def barrier(self) -> None:
    """Synchronize all processes.

    This creates a barrier that blocks until all processes reach this point.

    """
    # Imported locally, matching the other methods in this class.
    import torch.distributed as dist

    # Blocks until every rank in the default process group arrives here.
    dist.barrier()

cleanup

cleanup()

Cleanup distributed process group.

This should be called at the end of training.

Source code in src/formed/integrations/torch/distributors.py
561
562
563
564
565
566
567
568
569
570
def cleanup(self) -> None:
    """Tear down the distributed process group, if one is active.

    Safe to call when no group is initialized (it is then a no-op).
    This should be called at the end of training.

    """
    import torch.distributed as dist

    # Guard clause: nothing to destroy if no group was ever created.
    if not dist.is_initialized():
        return
    dist.destroy_process_group()

formed.integrations.torch.initializers

BaseTensorInitializer

Bases: Registrable

UniformTensorInitializer

UniformTensorInitializer(shape, low=0.0, high=1.0)

Bases: BaseTensorInitializer

Source code in src/formed/integrations/torch/initializers.py
14
15
16
17
def __init__(self, shape: Sequence[int], low: float = 0.0, high: float = 1.0):
    """Configure a uniform-distribution initializer.

    Args:
        shape: Target tensor shape.
        low: Lower bound of the uniform distribution.
        high: Upper bound of the uniform distribution.
    """
    self._shape = shape
    self._low = low
    self._high = high

NormalTensorInitializer

NormalTensorInitializer(shape, mean=0.0, std=1.0)

Bases: BaseTensorInitializer

Source code in src/formed/integrations/torch/initializers.py
25
26
27
28
def __init__(self, shape: Sequence[int], mean: float = 0.0, std: float = 1.0):
    """Configure a normal-distribution initializer.

    Args:
        shape: Target tensor shape.
        mean: Mean of the normal distribution.
        std: Standard deviation of the normal distribution.
    """
    self._shape = shape
    self._mean = mean
    self._std = std

XavierUniformTensorInitializer

XavierUniformTensorInitializer(shape, gain=1.0)

Bases: BaseTensorInitializer

Source code in src/formed/integrations/torch/initializers.py
36
37
38
def __init__(self, shape: Sequence[int], gain: float = 1.0):
    """Configure a Xavier-uniform initializer.

    Args:
        shape: Target tensor shape.
        gain: Scaling factor applied to the computed bounds.
    """
    self._shape = shape
    self._gain = gain

XavierNormalTensorInitializer

XavierNormalTensorInitializer(shape, gain=1.0)

Bases: BaseTensorInitializer

Source code in src/formed/integrations/torch/initializers.py
48
49
50
def __init__(self, shape: Sequence[int], gain: float = 1.0):
    """Configure a Xavier-normal initializer.

    Args:
        shape: Target tensor shape.
        gain: Scaling factor applied to the computed std.
    """
    self._shape = shape
    self._gain = gain

KaimingUniformTensorInitializer

KaimingUniformTensorInitializer(
    shape, a=0, mode="fan_in", nonlinearity="leaky_relu"
)

Bases: BaseTensorInitializer

Source code in src/formed/integrations/torch/initializers.py
60
61
62
63
64
65
66
67
68
69
70
def __init__(
    self,
    shape: Sequence[int],
    a: float = 0,
    mode: torch.nn.init._FanMode = "fan_in",
    nonlinearity: torch.nn.init._NonlinearityType = "leaky_relu",
):
    """Configure a Kaiming-uniform initializer.

    Args:
        shape: Target tensor shape.
        a: Negative-slope parameter (as in `torch.nn.init.kaiming_uniform_`).
        mode: Fan mode, `"fan_in"` or `"fan_out"`.
        nonlinearity: Nonlinearity name used for the gain calculation.
    """
    self._shape = shape
    self._a = a
    self._mode: torch.nn.init._FanMode = mode
    self._nonlinearity: torch.nn.init._NonlinearityType = nonlinearity

KaimingNormalTensorInitializer

KaimingNormalTensorInitializer(
    shape, a=0, mode="fan_in", nonlinearity="leaky_relu"
)

Bases: BaseTensorInitializer

Source code in src/formed/integrations/torch/initializers.py
80
81
82
83
84
85
86
87
88
89
90
def __init__(
    self,
    shape: Sequence[int],
    a: float = 0,
    mode: torch.nn.init._FanMode = "fan_in",
    nonlinearity: torch.nn.init._NonlinearityType = "leaky_relu",
):
    """Configure a Kaiming-normal initializer.

    Args:
        shape: Target tensor shape.
        a: Negative-slope parameter (as in `torch.nn.init.kaiming_normal_`).
        mode: Fan mode, `"fan_in"` or `"fan_out"`.
        nonlinearity: Nonlinearity name used for the gain calculation.
    """
    self._shape = shape
    self._a = a
    self._mode: torch.nn.init._FanMode = mode
    self._nonlinearity: torch.nn.init._NonlinearityType = nonlinearity

OrthogonalTensorInitializer

OrthogonalTensorInitializer(shape, gain=1.0)

Bases: BaseTensorInitializer

Source code in src/formed/integrations/torch/initializers.py
100
101
102
def __init__(self, shape: Sequence[int], gain: float = 1.0):
    """Configure an orthogonal initializer.

    Args:
        shape: Target tensor shape.
        gain: Multiplicative factor applied to the orthogonal matrix.
    """
    self._shape = shape
    self._gain = gain

SparseTensorInitializer

SparseTensorInitializer(shape, sparsity=0.1, std=0.01)

Bases: BaseTensorInitializer

Source code in src/formed/integrations/torch/initializers.py
112
113
114
115
def __init__(self, shape: Sequence[int], sparsity: float = 0.1, std: float = 0.01):
    """Configure a sparse initializer.

    Args:
        shape: Target tensor shape.
        sparsity: Fraction of elements to zero out.
        std: Standard deviation of the non-zero values.
    """
    self._shape = shape
    self._sparsity = sparsity
    self._std = std

ZerosTensorInitializer

ZerosTensorInitializer(shape)

Bases: BaseTensorInitializer

Source code in src/formed/integrations/torch/initializers.py
125
126
def __init__(self, shape: Sequence[int]):
    """Configure an all-zeros initializer for a tensor of the given shape."""
    self._shape = shape

OnesTensorInitializer

OnesTensorInitializer(shape)

Bases: BaseTensorInitializer

Source code in src/formed/integrations/torch/initializers.py
134
135
def __init__(self, shape: Sequence[int]):
    """Configure an all-ones initializer for a tensor of the given shape."""
    self._shape = shape

formed.integrations.torch.model

Base model abstraction for PyTorch models.

This module provides the base class for all PyTorch models in the framework, integrating torch.nn.Module with the registrable pattern for configuration-based model instantiation.

Key Features
  • Integration with PyTorch Module system
  • Registrable pattern for configuration-based instantiation
  • Generic type support for inputs, outputs, and parameters
  • Compatible with TorchTrainer for end-to-end training

Examples:

>>> from formed.integrations.torch import BaseTorchModel
>>> import torch
>>> import torch.nn as nn
>>>
>>> @BaseTorchModel.register("my_model")
... class MyModel(BaseTorchModel[dict, torch.Tensor, None]):
...     def __init__(self, hidden_dim: int):
...         super().__init__()
...         self.linear = nn.Linear(10, hidden_dim)
...
...     def forward(self, inputs: dict, params: None = None) -> torch.Tensor:
...         return self.linear(inputs["features"])

BaseTorchModel

Bases: Module, Registrable, Generic[ModelInputT, ModelOutputT, ModelParamsT]

Base class for all PyTorch models in the framework.

This class combines PyTorch's nn.Module with the registrable pattern, allowing models to be instantiated from configuration files and seamlessly integrated with the training infrastructure.

CLASS TYPE PARAMETER DESCRIPTION
ModelInputT

Type of input data to the model.

ModelOutputT

Type of model output.

ModelParamsT

Type of additional parameters (typically None or a dataclass).

Note

Subclasses should implement forward() to define the forward pass. Models are automatically compatible with TorchTrainer when registered.

forward

forward(inputs, params=None)

Forward pass of the model.

PARAMETER DESCRIPTION
inputs

Input data to the model.

TYPE: ModelInputT

params

Optional additional parameters for the forward pass.

TYPE: ModelParamsT | None DEFAULT: None

RETURNS DESCRIPTION
ModelOutputT

Model output.

RAISES DESCRIPTION
NotImplementedError

This method must be implemented by subclasses.

Source code in src/formed/integrations/torch/model.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def forward(self, inputs: ModelInputT, params: ModelParamsT | None = None) -> ModelOutputT:
    """Forward pass of the model.

    This base implementation is abstract and only raises; subclasses
    define the actual computation.

    Args:
        inputs: Input data to the model.
        params: Optional additional parameters for the forward pass.

    Returns:
        Model output.

    Raises:
        NotImplementedError: This method must be implemented by subclasses.

    """
    raise NotImplementedError()

formed.integrations.torch.schedulers

Learning rate schedulers for PyTorch models.

This module provides custom learning rate schedulers that extend PyTorch's standard scheduler functionality, including cosine annealing with warm restarts and warmup phases.

Available Schedulers
  • CosineLRScheduler: Cosine annealing with optional restarts and warmup
Features
  • Cosine decay with configurable cycle length
  • Warm restarts with cycle multiplier
  • Learning rate warmup phase
  • Cycle-based decay multiplier
  • Compatible with Colt registration system

Examples:

>>> from formed.integrations.torch.schedulers import CosineLRScheduler
>>>
>>> scheduler = CosineLRScheduler(
...     optimizer,
...     t_initial=100,
...     lr_min=1e-6,
...     warmup_t=5,
...     warmup_lr_init=1e-5
... )
>>> for epoch in range(num_epochs):
...     train(...)
...     scheduler.step(epoch + 1)

CosineLRScheduler

CosineLRScheduler(
    optimizer,
    t_initial,
    lr_min=0.0,
    cycle_mul=1.0,
    cycle_decay=1.0,
    cycle_limit=1,
    warmup_t=0,
    warmup_lr_init=0.0,
    warmup_prefix=False,
    t_in_epochs=True,
    last_epoch=-1,
)

Bases: LRScheduler

Cosine annealing learning rate scheduler with warm restarts.

Implements the SGDR (Stochastic Gradient Descent with Warm Restarts) algorithm described in https://arxiv.org/abs/1608.03983.

This scheduler decreases the learning rate following a cosine curve, optionally restarting the schedule multiple times during training. It also supports a warmup phase at the beginning.

PARAMETER DESCRIPTION
optimizer

Wrapped optimizer.

TYPE: Optimizer

t_initial

Number of iterations/epochs for the first cycle.

TYPE: int

lr_min

Minimum learning rate. Default: 0.

TYPE: float DEFAULT: 0.0

cycle_mul

Multiplier for cycle length after each restart. Default: 1.0.

TYPE: float DEFAULT: 1.0

cycle_decay

Decay factor applied to learning rate at each restart. Default: 1.0.

TYPE: float DEFAULT: 1.0

cycle_limit

Maximum number of restart cycles (0 means no limit). Default: 1.

TYPE: int DEFAULT: 1

warmup_t

Number of warmup iterations/epochs. Default: 0.

TYPE: int DEFAULT: 0

warmup_lr_init

Initial learning rate during warmup. Default: 0.

TYPE: float DEFAULT: 0.0

warmup_prefix

If True, warmup iterations don't count toward t_initial. Default: False.

TYPE: bool DEFAULT: False

t_in_epochs

If True, t values are in epochs; otherwise in iterations. Default: True.

TYPE: bool DEFAULT: True

last_epoch

The index of last epoch. Default: -1.

TYPE: int DEFAULT: -1

Examples:

>>> # Create scheduler with 100 epoch cycles and 5 epoch warmup
>>> scheduler = CosineLRScheduler(
...     optimizer,
...     t_initial=100,
...     lr_min=1e-6,
...     cycle_mul=2.0,  # Each cycle is 2x longer
...     warmup_t=5,
...     warmup_lr_init=1e-5
... )
>>>
>>> # Update learning rate each epoch
>>> for epoch in range(num_epochs):
...     train_one_epoch(...)
...     scheduler.step(epoch + 1)
Source code in src/formed/integrations/torch/schedulers.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def __init__(
    self,
    optimizer: optim.Optimizer,
    t_initial: int,
    lr_min: float = 0.0,
    cycle_mul: float = 1.0,
    cycle_decay: float = 1.0,
    cycle_limit: int = 1,
    warmup_t: int = 0,
    warmup_lr_init: float = 0.0,
    warmup_prefix: bool = False,
    t_in_epochs: bool = True,
    last_epoch: int = -1,
) -> None:
    """Initialize the cosine scheduler; see the class docstring for parameter details."""
    assert t_initial > 0, "t_initial must be positive"
    assert lr_min >= 0, "lr_min must be non-negative"

    self.t_initial = t_initial
    self.lr_min = lr_min
    self.cycle_mul = cycle_mul
    self.cycle_decay = cycle_decay
    self.cycle_limit = cycle_limit
    self.warmup_t = warmup_t
    self.warmup_lr_init = warmup_lr_init
    self.warmup_prefix = warmup_prefix
    self.t_in_epochs = t_in_epochs

    # Store base learning rates before calling super().__init__, since
    # warmup_steps must be derived from them first.
    # NOTE(review): LRScheduler.__init__ also manages `base_lrs`; confirm
    # this early assignment stays consistent with the parent class.
    self.base_lrs = [group["lr"] for group in optimizer.param_groups]

    # Initialize warmup steps: per-group linear increment applied on each
    # warmup step by get_lr().
    if self.warmup_t:
        self.warmup_steps = [(base_lr - warmup_lr_init) / self.warmup_t for base_lr in self.base_lrs]
    else:
        # Placeholder values; get_lr()'s warmup branch requires t < warmup_t,
        # so these are never read when warmup_t == 0.
        self.warmup_steps = [1.0 for _ in self.base_lrs]

    super().__init__(optimizer, last_epoch)

t_initial instance-attribute

t_initial = t_initial

lr_min instance-attribute

lr_min = lr_min

cycle_mul instance-attribute

cycle_mul = cycle_mul

cycle_decay instance-attribute

cycle_decay = cycle_decay

cycle_limit instance-attribute

cycle_limit = cycle_limit

warmup_t instance-attribute

warmup_t = warmup_t

warmup_lr_init instance-attribute

warmup_lr_init = warmup_lr_init

warmup_prefix instance-attribute

warmup_prefix = warmup_prefix

t_in_epochs instance-attribute

t_in_epochs = t_in_epochs

base_lrs instance-attribute

base_lrs = [(group['lr']) for group in (param_groups)]

warmup_steps instance-attribute

warmup_steps = [
    ((base_lr - warmup_lr_init) / warmup_t)
    for base_lr in (base_lrs)
]

get_lr

get_lr()

Compute learning rate at the current step.

RETURNS DESCRIPTION
list[float | Tensor]

List of learning rates for each parameter group.

Source code in src/formed/integrations/torch/schedulers.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
def get_lr(self) -> list[float | torch.Tensor]:
    """Compute learning rate at the current step.

    Returns:
        List of learning rates for each parameter group.

    """
    # Current timestep (starts from 0 after first step())
    t = self.last_epoch

    # Warmup phase: linearly interpolate from warmup_lr_init to base_lr
    if t < self.warmup_t:
        lrs = [self.warmup_lr_init + (t + 1) * step for step in self.warmup_steps]
        return lrs

    # Adjust t if warmup is a prefix
    if self.warmup_prefix:
        t = t - self.warmup_t

    # Determine current cycle
    if self.cycle_mul == 1.0:
        # Simple case: equal cycles
        cycle = t // self.t_initial
        t_curr = t % self.t_initial
        t_i = self.t_initial
    else:
        # Geometric progression of cycle lengths
        # Find which cycle we're in using logarithmic calculation
        cycle = int(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
        # Compute cumulative time up to current cycle
        t_prev = self.t_initial * (1 - self.cycle_mul**cycle) / (1 - self.cycle_mul)
        t_curr = t - t_prev
        t_i = self.t_initial * (self.cycle_mul**cycle)

    # Apply cycle limit
    if self.cycle_limit > 0 and cycle >= self.cycle_limit:
        return [self.lr_min for _ in self.base_lrs]

    # Compute cycle decay
    cycle_decay = self.cycle_decay**cycle

    # Cosine annealing
    lrs = [
        self.lr_min + (base_lr - self.lr_min) * cycle_decay * 0.5 * (1 + math.cos(math.pi * t_curr / t_i))
        for base_lr in self.base_lrs
    ]

    return lrs

get_cycle_length

get_cycle_length(cycles=0)

Calculate total number of iterations for a given number of cycles.

PARAMETER DESCRIPTION
cycles

Number of cycles (0 means current cycle).

TYPE: int DEFAULT: 0

RETURNS DESCRIPTION
int

Total number of iterations.

Source code in src/formed/integrations/torch/schedulers.py
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
def get_cycle_length(self, cycles: int = 0) -> int:
    """Total number of scheduler steps spanned by `cycles` restart cycles.

    Args:
        cycles: How many cycles to measure; `0` means use the configured
            cycle limit (falling back to a single cycle when unlimited).

    Returns:
        Number of iterations covered, including warmup when it is a prefix.

    """
    if cycles <= 0:
        cycles = 1 if self.cycle_limit <= 0 else self.cycle_limit

    if self.cycle_mul == 1.0:
        # Equal-length cycles: a simple product.
        total = self.t_initial * cycles
    else:
        # Geometric progression of cycle lengths: sum of the series.
        total = int(self.t_initial * (1 - self.cycle_mul**cycles) / (1 - self.cycle_mul))

    return total + self.warmup_t if self.warmup_prefix else total

state_dict

state_dict()

Return the state of the scheduler as a dict.

RETURNS DESCRIPTION
dict[str, Any]

Dictionary containing scheduler state.

Source code in src/formed/integrations/torch/schedulers.py
191
192
193
194
195
196
197
198
199
200
201
202
203
def state_dict(self) -> dict[str, Any]:
    """Return the scheduler state as a serializable dict.

    The optimizer reference and transient step-tracking fields are
    excluded; every other instance attribute is included as-is.

    Returns:
        Dictionary containing scheduler state.

    """
    excluded = ("optimizer", "_get_lr_called_within_step", "_step_count")
    state: dict[str, Any] = {}
    for key, value in self.__dict__.items():
        if key not in excluded:
            state[key] = value
    return state

load_state_dict

load_state_dict(state_dict)

Load the scheduler state.

PARAMETER DESCRIPTION
state_dict

Scheduler state dict.

TYPE: dict[str, Any]

Source code in src/formed/integrations/torch/schedulers.py
205
206
207
208
209
210
211
212
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Restore scheduler state previously produced by `state_dict`.

    Args:
        state_dict: Scheduler state dict; its entries overwrite matching
            instance attributes directly.

    """
    vars(self).update(state_dict)

formed.integrations.torch.utils

Utility functions for PyTorch integration.

PoolingMethod module-attribute

PoolingMethod = Literal[
    "mean", "max", "min", "sum", "first", "last", "hier"
]

set_random_seed

set_random_seed(seed)

Set random seed for reproducibility across torch, numpy, and random.

PARAMETER DESCRIPTION
seed

Random seed value.

TYPE: int

Source code in src/formed/integrations/torch/utils.py
17
18
19
20
21
22
23
24
25
26
27
def set_random_seed(seed: int) -> None:
    """Seed the stdlib, numpy, and torch RNGs for reproducible runs.

    Args:
        seed: Random seed value.
    """
    # The three RNGs are independent, so seeding order does not matter.
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    # Cover all CUDA devices when a GPU build is active.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

ensure_torch_tensor

ensure_torch_tensor(x, dtype=None, device=None)

Convert array-like objects to PyTorch tensors.

This function converts various array-like objects (numpy arrays, lists, etc.) to PyTorch tensors. If the input is already a tensor, it returns it with the appropriate dtype and device.

The device can be specified explicitly via the device parameter, or it will be taken from the context set by use_device(). If neither is provided and the input is not already a tensor, the tensor will be created on CPU.

PARAMETER DESCRIPTION
x

Input data (tensor, numpy array, list, etc.)

TYPE: TensorCompatible

dtype

Optional dtype for the output tensor.

TYPE: Optional[dtype] DEFAULT: None

device

Optional device for the output tensor. If None, uses the device from context (set by use_device()). If the input is already a tensor, its device is preserved unless explicitly specified.

TYPE: Optional[Union[device, str]] DEFAULT: None

RETURNS DESCRIPTION
Tensor

PyTorch tensor on the specified device with the specified dtype.

Examples:

>>> import numpy as np
>>> from formed.integrations.torch import ensure_torch_tensor, use_device
>>> arr = np.array([1, 2, 3])
>>>
>>> # Without context
>>> tensor = ensure_torch_tensor(arr)
>>> tensor.device
device(type='cpu')
>>>
>>> # With context
>>> with use_device("cuda:0"):
...     tensor = ensure_torch_tensor(arr)
...     print(tensor.device)
cuda:0
Source code in src/formed/integrations/torch/utils.py
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
def ensure_torch_tensor(
    x: TensorCompatible,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
) -> torch.Tensor:
    """Convert array-like objects to PyTorch tensors.

    Existing tensors are returned with dtype/device adjusted as needed;
    numpy arrays and other array-likes are converted. When `device` is not
    given, the device from the `use_device()` context is used; absent both,
    non-tensor inputs end up on CPU.

    Args:
        x: Input data (tensor, numpy array, list, etc.)
        dtype: Optional dtype for the output tensor.
        device: Optional device for the output tensor. If None, uses the
            device from context (set by `use_device()`). An existing
            tensor's device is preserved unless explicitly overridden.

    Returns:
        PyTorch tensor on the specified device with the specified dtype.

    Examples:
        >>> import numpy as np
        >>> from formed.integrations.torch import ensure_torch_tensor, use_device
        >>> arr = np.array([1, 2, 3])
        >>> ensure_torch_tensor(arr).device
        device(type='cpu')
        >>> with use_device("cuda:0"):
        ...     print(ensure_torch_tensor(arr).device)
        cuda:0
    """
    import numpy as np

    # Fall back to the context-provided default device (see `use_device`).
    if device is None:
        device = get_device()

    if isinstance(x, torch.Tensor):
        wants_dtype = dtype is not None and x.dtype != dtype
        wants_device = device is not None and x.device != torch.device(device)
        if not (wants_dtype or wants_device):
            return x
        to_kwargs = {}
        if dtype is not None:
            to_kwargs["dtype"] = dtype
        if device is not None:
            to_kwargs["device"] = device
        return x.to(**to_kwargs)

    if isinstance(x, np.ndarray):
        # Default: convert float64 to float32, PyTorch's customary dtype.
        if dtype is None and x.dtype == np.float64:
            dtype = torch.float32
        result = torch.from_numpy(x)
    else:
        # Any other array-like object.
        result = torch.as_tensor(x)

    if dtype is not None:
        result = result.to(dtype)
    if device is not None:
        result = result.to(device)
    return result

move_to_device

move_to_device(inputs, device)

Move tensor inputs to the appropriate device.

This function only moves existing torch.Tensor objects to the target device. Other types (numpy arrays, primitives, etc.) are left unchanged. Users should explicitly convert numpy arrays to tensors in their model's forward method using ensure_torch_tensor().

Source code in src/formed/integrations/torch/utils.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
def move_to_device(inputs: ModelInputT, device: Optional[Union[torch.device, str]]) -> ModelInputT:
    """Move tensor inputs to the appropriate device.

    This function only moves existing torch.Tensor objects to the target device.
    Other types (numpy arrays, primitives, etc.) are left unchanged. Users
    should explicitly convert numpy arrays to tensors in their model's
    forward method using `ensure_torch_tensor()`.

    Containers (dicts, lists, tuples, namedtuples) are rebuilt with moved
    contents; arbitrary objects with a `__dict__` have their attributes
    updated in place. A visited-set guards against cyclic references.
    """
    from typing import Any

    visited: set[int] = set()

    def _move(obj: Any) -> Any:
        # Tensors are the only values that actually get moved.
        if isinstance(obj, torch.Tensor):
            return obj.to(device)

        # Primitives and None need no conversion.
        if obj is None or isinstance(obj, (int, float, str, bool, type)):
            return obj

        # Check if already visited to avoid infinite recursion on cycles.
        obj_id = id(obj)
        if obj_id in visited:
            return obj
        visited.add(obj_id)

        # Rebuild dicts with converted values.
        if isinstance(obj, dict):
            return {k: _move(v) for k, v in obj.items()}

        # BUGFIX: namedtuples need positional field-by-field construction;
        # calling type(obj) with a single generator argument (as for plain
        # tuples below) raises TypeError on namedtuple classes.
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            return type(obj)(*(_move(v) for v in obj))

        # Rebuild lists/tuples preserving the concrete type.
        if isinstance(obj, (list, tuple)):
            return type(obj)(_move(v) for v in obj)

        # Handle objects with __dict__ (but not built-in types).
        if hasattr(obj, "__dict__") and not isinstance(obj, type):
            try:
                for key, value in list(obj.__dict__.items()):
                    # Skip dunder attributes
                    if not key.startswith("__"):
                        setattr(obj, key, _move(value))
            except (TypeError, AttributeError):
                # Skip objects that don't allow attribute modification
                pass
            return obj

        return obj

    return cast(ModelInputT, _move(inputs))

determine_ndim

determine_ndim(first, *args)
Source code in src/formed/integrations/torch/utils.py
164
165
166
167
168
169
170
171
172
173
174
175
176
def determine_ndim(
    first: int,
    *args: Optional[Union[int, Callable[[int], int]]],
) -> int:
    """Fold a sequence of dimension transforms into a final dimensionality.

    Starting from `first`, each argument either replaces the running value
    (an int), transforms it (a callable), or leaves it untouched (None).
    """
    ndim = first
    for transform in args:
        if transform is None:
            continue
        ndim = transform(ndim) if callable(transform) else transform
    return ndim

masked_pool

masked_pool(
    inputs,
    *,
    mask=None,
    pooling="mean",
    normalize=False,
    window_size=None,
)

Apply masked pooling over the sequence dimension.

PARAMETER DESCRIPTION
inputs

Input tensor of shape (batch_size, seq_len, feature_dim).

TYPE: Tensor

mask

Mask tensor of shape (batch_size, seq_len). True/1 indicates valid positions.

TYPE: Optional[Tensor] DEFAULT: None

pooling

Pooling method or sequence of methods.

TYPE: Union[PoolingMethod, Sequence[PoolingMethod]] DEFAULT: 'mean'

normalize

Whether to L2-normalize before pooling.

TYPE: bool DEFAULT: False

window_size

Window size for hierarchical pooling (required if pooling="hier").

TYPE: Optional[int] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Pooled tensor of shape (batch_size, feature_dim * num_pooling_methods).

Source code in src/formed/integrations/torch/utils.py
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
def masked_pool(
    inputs: torch.Tensor,
    *,
    mask: Optional[torch.Tensor] = None,
    pooling: Union[PoolingMethod, Sequence[PoolingMethod]] = "mean",
    normalize: bool = False,
    window_size: Optional[int] = None,
) -> torch.Tensor:
    """Pool a `(batch_size, seq_len, feature_dim)` tensor over the sequence axis.

    Positions where `mask` is falsy are excluded from the pooled statistics.
    Several pooling methods may be requested at once; their results are
    concatenated along the feature axis.

    Args:
        inputs: Input tensor of shape `(batch_size, seq_len, feature_dim)`.
        mask: Mask of shape `(batch_size, seq_len)`; truthy marks valid positions.
            If None, every position is treated as valid.
        pooling: A pooling method name, or a sequence of them.
        normalize: If True, L2-normalize feature vectors before pooling.
        window_size: Sliding-window width, required for `"hier"` pooling.

    Returns:
        Tensor of shape `(batch_size, feature_dim * num_pooling_methods)`.

    Raises:
        ValueError: For an unknown pooling method, or `"hier"` without `window_size`.
    """
    if normalize:
        inputs = F.normalize(inputs, p=2, dim=-1)

    if mask is None:
        mask = torch.ones(inputs.shape[:-1], dtype=torch.bool, device=inputs.device)
    elif mask.dtype != torch.bool:
        mask = mask.bool()

    methods = [pooling] if isinstance(pooling, str) else list(pooling)
    expanded = mask.unsqueeze(-1)
    pooled_outputs: list[torch.Tensor] = []

    for name in methods:
        if name == "mean":
            # Sum over valid positions, divide by (at least one) valid count.
            total = (inputs * expanded).sum(dim=1)
            count = mask.sum(dim=1, keepdim=True).clamp(min=1)
            summary = total / count
        elif name == "max":
            summary = inputs.masked_fill(~expanded, float("-inf")).max(dim=1).values
        elif name == "min":
            summary = inputs.masked_fill(~expanded, float("inf")).min(dim=1).values
        elif name == "sum":
            summary = (inputs * expanded).sum(dim=1)
        elif name == "first":
            summary = inputs[:, 0, :]
        elif name == "last":
            # Index of the last valid position per sequence (0-based).
            last_indices = (mask.sum(dim=1).clamp(min=1) - 1).long()
            rows = torch.arange(inputs.size(0), device=inputs.device)
            summary = inputs[rows, last_indices]
        elif name == "hier":
            # Hierarchical pooling: max over means of a sliding window.
            if window_size is None:
                raise ValueError("window_size must be specified for hierarchical pooling")

            feature_dim = inputs.size(-1)
            per_sequence = []
            for row, row_mask in zip(inputs, mask):
                valid = row[row_mask]
                n_valid = valid.size(0)
                if n_valid < window_size:
                    # Shorter than one window: fall back to a plain mean.
                    per_sequence.append(valid.mean(dim=0))
                    continue
                best = torch.full((feature_dim,), float("-inf"), device=inputs.device)
                for start in range(n_valid - window_size + 1):
                    best = torch.maximum(best, valid[start : start + window_size].mean(dim=0))
                per_sequence.append(best)
            summary = torch.stack(per_sequence)
        else:
            raise ValueError(f"Unknown pooling method: {name}")

        pooled_outputs.append(summary)

    if len(pooled_outputs) == 1:
        return pooled_outputs[0]
    return torch.cat(pooled_outputs, dim=-1)

info_value_of_dtype

info_value_of_dtype(dtype)

Returns the finfo or iinfo object of a given PyTorch data type. Does not allow torch.bool.

Source code in src/formed/integrations/torch/utils.py
277
278
279
280
281
282
283
284
def info_value_of_dtype(dtype: torch.dtype) -> Union[torch.finfo, torch.iinfo]:
    """Return numeric limits for `dtype`: `torch.finfo` for floating types, `torch.iinfo` otherwise.

    Raises:
        TypeError: If `dtype` is `torch.bool`, which has neither finfo nor iinfo.
    """
    if dtype == torch.bool:
        raise TypeError("Does not support torch.bool")
    return torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)

min_value_of_dtype

min_value_of_dtype(dtype)

Returns the minimum value of a given PyTorch data type. Does not allow torch.bool.

Source code in src/formed/integrations/torch/utils.py
287
288
289
def min_value_of_dtype(dtype: torch.dtype) -> Union[float, int]:
    """Return the smallest representable value of `dtype`. Does not allow `torch.bool`."""
    limits = info_value_of_dtype(dtype)
    return limits.min

max_value_of_dtype

max_value_of_dtype(dtype)

Returns the maximum value of a given PyTorch data type. Does not allow torch.bool.

Source code in src/formed/integrations/torch/utils.py
292
293
294
def max_value_of_dtype(dtype: torch.dtype) -> Union[float, int]:
    """Return the largest representable value of `dtype`. Does not allow `torch.bool`."""
    limits = info_value_of_dtype(dtype)
    return limits.max

tiny_value_of_dtype

tiny_value_of_dtype(dtype)

Returns a moderately tiny value for a given PyTorch data type that is used to avoid numerical issues such as division by zero. This is different from info_value_of_dtype(dtype).tiny, because that value can be small enough to cause NaN bugs. Only supports floating point dtypes.

Source code in src/formed/integrations/torch/utils.py
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
def tiny_value_of_dtype(dtype: torch.dtype) -> float | int:
    """
    Returns a moderately tiny value for a given PyTorch data type that is used to avoid numerical
    issues such as division by zero.
    This is different from `info_value_of_dtype(dtype).tiny` because it causes some NaN bugs.
    Only supports floating point dtypes.
    """
    if not dtype.is_floating_point:
        raise TypeError("Only supports floating point dtypes.")
    if dtype in (torch.float, torch.double):
        return 1e-13
    elif dtype == torch.half:
        return 1e-4
    else:
        raise TypeError("Does not support dtype " + str(dtype))

masked_mean

masked_mean(vector, mask, dim, keepdim=False)
Source code in src/formed/integrations/torch/utils.py
314
315
316
317
318
319
320
321
322
323
def masked_mean(
    vector: _TensorT,
    mask: torch.Tensor,
    dim: int,
    keepdim: bool = False,
) -> _TensorT:
    """Mean of `vector` along `dim`, counting only positions where `mask` is True.

    The divisor is clamped to a tiny positive value so fully-masked slices
    do not divide by zero.
    """
    zeroed = vector.masked_fill(~mask, 0.0)
    total = torch.sum(zeroed, dim=dim, keepdim=keepdim)
    count = torch.sum(mask, dim=dim, keepdim=keepdim).float()
    return cast(_TensorT, total / count.clamp(min=tiny_value_of_dtype(torch.float)))

masked_softmax

masked_softmax(
    vector, mask, dim=-1, memory_efficient=False
)
Source code in src/formed/integrations/torch/utils.py
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
def masked_softmax(
    vector: _TensorT,
    mask: torch.Tensor,
    dim: int = -1,
    memory_efficient: bool = False,
) -> _TensorT:
    """Softmax over `vector` along `dim`, restricted to positions where `mask` is truthy.

    Args:
        vector: Logits tensor.
        mask: Boolean-ish mask; broadcast up to `vector`'s rank by unsqueezing dim 1.
        dim: Axis along which to normalize.
        memory_efficient: If True, fill masked logits with the dtype's minimum
            instead of renormalizing after the softmax.
    """
    # Broadcast the mask up to the vector's rank.
    while mask.dim() < vector.dim():
        mask = mask.unsqueeze(1)

    if memory_efficient:
        filled = vector.masked_fill(~mask, min_value_of_dtype(vector.dtype))
        probs = torch.nn.functional.softmax(filled, dim=dim)
    else:
        # Zero out masked logits first to limit numerical error, then
        # renormalize over the surviving positions.
        probs = torch.nn.functional.softmax(vector * mask, dim=dim) * mask
        probs = probs / (probs.sum(dim=dim, keepdim=True) + tiny_value_of_dtype(probs.dtype))
    return cast(_TensorT, probs)

formed.integrations.torch.modules.embedders

Text embedding modules for PyTorch models.

This module provides embedders that convert tokenized text into dense vector representations. Embedders handle various text representations including surface forms, part-of-speech tags, and character sequences.

Key Components
  • BaseEmbedder: Abstract base class for all embedders
  • TokenEmbedder: Embeds token ID sequences into dense vectors
  • AnalyzedTextEmbedder: Combines multiple embedding types (surface, POS, chars)
Features
  • Support for nested token sequences (e.g., word -> character)
  • Automatic masking and padding handling
  • Configurable vectorization for character-level embeddings
  • Concatenation of multiple embedding types

Examples:

>>> from formed.integrations.torch.modules import TokenEmbedder, AnalyzedTextEmbedder
>>> import torch.nn as nn
>>>
>>> # Simple token embedder
>>> embedder = TokenEmbedder(
...     vocab_size=10000,
...     embedding_dim=128
... )
>>>
>>> # Multi-feature embedder
>>> embedder = AnalyzedTextEmbedder(
...     surface=TokenEmbedder(vocab_size=10000, embedding_dim=128),
...     postag=TokenEmbedder(vocab_size=50, embedding_dim=32)
... )

SurfaceBatchT module-attribute

SurfaceBatchT = TypeVar(
    "SurfaceBatchT", bound="IIDSequenceBatch", default=Any
)

PostagBatchT module-attribute

PostagBatchT = TypeVar(
    "PostagBatchT",
    bound=Union["IIDSequenceBatch", None],
    default=Any,
)

CharacterBatchT module-attribute

CharacterBatchT = TypeVar(
    "CharacterBatchT",
    bound=Union["IIDSequenceBatch", None],
    default=Any,
)

TokenVectorBatchT module-attribute

TokenVectorBatchT = TypeVar(
    "TokenVectorBatchT",
    bound=Union["IVariableTensorBatch", None],
    default=Any,
)

IVariableTensorBatch

Bases: Protocol[TensorCompatibleT]

Protocol for variable-length tensor batches.

ATTRIBUTE DESCRIPTION
tensor

Tensor of shape (batch_size, seq_len, feature_dim).

TYPE: TensorCompatibleT

mask

Attention mask of shape (batch_size, seq_len).

TYPE: TensorCompatibleT

tensor instance-attribute

tensor

mask instance-attribute

mask

IAnalyzedTextBatch

Bases: Protocol[SurfaceBatchT, PostagBatchT, CharacterBatchT, TokenVectorBatchT]

Protocol for analyzed text batches with multiple linguistic features.

ATTRIBUTE DESCRIPTION
surfaces

Surface form token IDs.

TYPE: SurfaceBatchT

postags

Part-of-speech tag IDs (optional).

TYPE: PostagBatchT

characters

Character sequence IDs (optional).

TYPE: CharacterBatchT

token_vectors

Token-level dense vectors (optional).

TYPE: TokenVectorBatchT

surfaces instance-attribute

surfaces

postags instance-attribute

postags

characters instance-attribute

characters

token_vectors instance-attribute

token_vectors

EmbedderOutput

Bases: NamedTuple

Output from an embedder.

ATTRIBUTE DESCRIPTION
embeddings

Dense embeddings of shape (batch_size, seq_len, embedding_dim).

TYPE: Tensor

mask

Attention mask of shape (batch_size, seq_len).

TYPE: Tensor

embeddings instance-attribute

embeddings

mask instance-attribute

mask

BaseEmbedder

Bases: Module, Registrable, Generic[_TextBatchT], ABC

Abstract base class for text embedders.

Embedders convert tokenized text into dense vector representations. They output both embeddings and attention masks.

CLASS TYPE PARAMETER DESCRIPTION
_TextBatchT

Type of input batch (e.g., IIDSequenceBatch, IAnalyzedTextBatch).

forward abstractmethod

forward(inputs)

Embed input tokens into dense vectors.

PARAMETER DESCRIPTION
inputs

Batch of tokenized text.

TYPE: _TextBatchT

RETURNS DESCRIPTION
EmbedderOutput

EmbedderOutput containing embeddings and mask.

Source code in src/formed/integrations/torch/modules/embedders.py
132
133
134
135
136
137
138
139
140
141
142
143
@abc.abstractmethod
def forward(self, inputs: _TextBatchT) -> EmbedderOutput:
    """Embed input tokens into dense vectors.

    Subclasses must return both the embeddings and the matching attention mask.

    Args:
        inputs: Batch of tokenized text.

    Returns:
        EmbedderOutput containing embeddings and mask.

    """
    raise NotImplementedError

get_output_dim abstractmethod

get_output_dim()

Get the output embedding dimension.

RETURNS DESCRIPTION
int

Embedding dimension.

Source code in src/formed/integrations/torch/modules/embedders.py
145
146
147
148
149
150
151
152
153
@abc.abstractmethod
def get_output_dim(self) -> int:
    """Get the output embedding dimension (last axis of the produced embeddings).

    Returns:
        Embedding dimension.

    """
    raise NotImplementedError

PassThroughEmbedder

Bases: BaseEmbedder[IVariableTensorBatch[TensorCompatibleT]]

Embedder that passes through input tensors unchanged.

This embedder is useful when the input tensors are already in the desired embedding format. It simply returns the input tensors and their masks.

Examples:

>>> from formed.integrations.torch.modules import PassThroughEmbedder
>>>
>>> embedder = PassThroughEmbedder()
>>> output = embedder(variable_tensor_batch)
>>> assert torch.equal(output.embeddings, variable_tensor_batch.tensor)
>>> assert torch.equal(output.mask, variable_tensor_batch.mask)

forward

forward(inputs)
Source code in src/formed/integrations/torch/modules/embedders.py
176
177
178
179
180
181
182
def forward(
    self,
    inputs: IVariableTensorBatch[TensorCompatibleT],
) -> EmbedderOutput:
    """Return the input tensor unchanged, paired with its mask cast to bool."""
    return EmbedderOutput(
        embeddings=ensure_torch_tensor(inputs.tensor),
        mask=ensure_torch_tensor(inputs.mask).bool(),
    )

get_output_dim

get_output_dim()
Source code in src/formed/integrations/torch/modules/embedders.py
184
185
def get_output_dim(self) -> int:
    # Pass-through embeddings keep whatever feature size the input already had,
    # so there is no statically known output dimension to report.
    raise NotImplementedError("PassThroughEmbedder does not have a fixed output dimension.")

TokenEmbedder

TokenEmbedder(
    initializer,
    *,
    padding_idx=0,
    freeze=False,
    vectorizer=None,
)

Bases: BaseEmbedder['IIDSequenceBatch']

Embedder for token ID sequences.

This embedder converts token IDs into dense embeddings using a learned embedding matrix. It supports both 2D (batch_size, seq_len) and 3D (batch_size, seq_len, char_len) token ID tensors.

For 3D inputs (e.g., character-level tokens within words), the embedder can either average the embeddings or apply a custom vectorizer.

PARAMETER DESCRIPTION
initializer

Tensor initializer or callable that returns the embedding tensor.

TYPE: BaseTensorInitializer | Callable[[], TensorCompatible]

padding_idx

Index of the padding token (default: 0).

TYPE: int DEFAULT: 0

vectorizer

Optional vectorizer for 3D inputs (character sequences).

TYPE: Optional[BaseSequenceVectorizer] DEFAULT: None

Examples:

>>> # Simple word embeddings from a randomly initialized (10000, 128) weight matrix
>>> embedder = TokenEmbedder(lambda: torch.randn(10000, 128))
>>> output = embedder(word_ids_batch)
>>>
>>> # Character-level embeddings with pooling
>>> from formed.integrations.torch.modules import BagOfEmbeddingsSequenceVectorizer
>>> embedder = TokenEmbedder(
...     lambda: torch.randn(256, 32),
...     vectorizer=BagOfEmbeddingsSequenceVectorizer(pooling="max")
... )
Source code in src/formed/integrations/torch/modules/embedders.py
219
220
221
222
223
224
225
226
227
228
229
230
231
def __init__(
    self,
    initializer: BaseTensorInitializer | Callable[[], TensorCompatible],
    *,
    padding_idx: int = 0,
    freeze: bool = False,
    vectorizer: Optional[BaseSequenceVectorizer] = None,
) -> None:
    """Build the embedding table from `initializer`.

    Args:
        initializer: Callable returning the `(vocab_size, embedding_dim)` weight matrix.
        padding_idx: Index of the padding token; its embedding stays zero.
        freeze: If True, the embedding weights are not updated during training.
        vectorizer: Optional vectorizer used to pool character-level (3D) inputs.
    """
    # Materialize the weight matrix before initializing the module tree.
    initial_weight = ensure_torch_tensor(initializer())

    super().__init__()
    self._embedding = nn.Embedding.from_pretrained(initial_weight, padding_idx=padding_idx, freeze=freeze)
    self._vectorizer = vectorizer

forward

forward(inputs)
Source code in src/formed/integrations/torch/modules/embedders.py
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
def forward(self, inputs: "IIDSequenceBatch") -> EmbedderOutput:
    """Embed token IDs; 3D character-level inputs are pooled to one vector per token."""
    token_ids = ensure_torch_tensor(inputs.ids)
    mask = ensure_torch_tensor(inputs.mask).bool()

    # More than 2 dims means nested (word -> character) ids; only rank 3 is allowed.
    nested = token_ids.ndim > 2
    if nested and token_ids.ndim != 3:
        raise ValueError("Token ids must be of shape (batch_size, seq_len) or (batch_size, seq_len, char_len)")

    if token_ids.shape != mask.shape:
        raise ValueError(f"Token ids and mask must have the same shape, got {token_ids.shape} and {mask.shape}")

    embeddings = self._embedding(token_ids)

    if nested:
        if self._vectorizer is not None:
            # Collapse (batch, seq) so the vectorizer sees one char sequence per row.
            batch_size, seq_len, char_len = token_ids.shape
            per_token = self._vectorizer(
                embeddings.view(batch_size * seq_len, char_len, -1),
                mask=mask.view(batch_size * seq_len, char_len),
            )
            embeddings = per_token.view(batch_size, seq_len, -1)
        else:
            # Default: mean over the character axis, counting only unmasked chars.
            char_counts = mask.sum(dim=-1, keepdim=True).clamp(min=1)
            embeddings = (embeddings * mask.unsqueeze(-1)).sum(dim=-2) / char_counts
        # A token is valid if any of its characters is valid.
        mask = mask.any(dim=-1)

    return EmbedderOutput(embeddings=embeddings, mask=mask)

get_output_dim

get_output_dim()
Source code in src/formed/integrations/torch/modules/embedders.py
268
269
def get_output_dim(self) -> int:
    # The embedding matrix's column count is the output feature size.
    return self._embedding.embedding_dim

PretrainedTransformerEmbedder

PretrainedTransformerEmbedder(
    model,
    auto_class=None,
    subcmodule=None,
    freeze=False,
    eval_mode=False,
    layer_to_use="last",
    gradient_checkpointing=None,
    **kwargs,
)

Bases: BaseEmbedder[IIDSequenceBatch]

Embedder using pretrained transformer models from Hugging Face.

This embedder wraps pretrained transformer models (BERT, RoBERTa, etc.) to extract contextualized embeddings. It uses the last hidden state from the transformer as the embedding representation.

PARAMETER DESCRIPTION
model

Either a model name/path string, PathLike object, or a PreTrainedModel instance. If a string or PathLike, the model will be loaded using transformers auto classes.

TYPE: Union[str, PathLike, PreTrainedModel]

auto_class

The auto class to use for loading the model.

TYPE: str | type[_BaseAutoModelClass] | None DEFAULT: None

subcmodule

Optional submodule path to extract from the loaded model (e.g., "encoder").

TYPE: str | None DEFAULT: None

freeze

If True, freezes all model parameters (no gradient computation).

TYPE: bool DEFAULT: False

**kwargs

Additional keyword arguments passed to the model loader.

TYPE: Any DEFAULT: {}

Examples:

>>> # Load a pretrained BERT model
>>> embedder = PretrainedTransformerEmbedder(
...     model="bert-base-uncased",
...     freeze=True
... )
>>>
>>> # Use a specific auto class
>>> from transformers import AutoModel
>>> embedder = PretrainedTransformerEmbedder(
...     model="roberta-base",
...     auto_class=AutoModel,
...     freeze=False
... )
>>>
>>> # Use an already loaded model
>>> from transformers import AutoModel
>>> model = AutoModel.from_pretrained("bert-base-uncased")
>>> embedder = PretrainedTransformerEmbedder(model=model)
Note

Models are cached using LRU cache by the load_pretrained_transformer utility. When freeze=True, all model parameters have requires_grad=False.

Source code in src/formed/integrations/torch/modules/embedders.py
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
def __init__(
    self,
    model: Union[str, PathLike, "PreTrainedModel"],
    auto_class: str | type["_BaseAutoModelClass"] | None = None,
    subcmodule: str | None = None,
    freeze: bool = False,
    eval_mode: bool = False,
    layer_to_use: Literal["embeddings", "last", "all"] = "last",
    gradient_checkpointing: bool | None = None,
    **kwargs: Any,
) -> None:
    """Load (or wrap) a pretrained transformer and configure embedding extraction.

    Args:
        model: Model name/path to load via transformers auto classes, or an
            already-instantiated ``PreTrainedModel``.
        auto_class: Auto class (or its name) used when loading from a name/path.
        subcmodule: Submodule path to extract from the loaded model; forwarded
            as ``submodule``. NOTE(review): the parameter name looks like a typo
            for ``submodule`` but is kept for caller compatibility.
        freeze: If True, disable gradients for all model parameters.
        eval_mode: If True, keep the wrapped model in eval mode permanently.
        layer_to_use: Which representation to expose: the embedding layer only
            (``"embeddings"``), the last hidden state (``"last"``), or a learned
            scalar mix over all hidden layers (``"all"``).
        gradient_checkpointing: If given, toggles gradient checkpointing in the
            model config.
        **kwargs: Extra arguments forwarded to the model loader.
    """
    if isinstance(model, (str, PathLike)):
        from formed.integrations.transformers import load_pretrained_transformer

        # Calling `__wrapped__` presumably bypasses the loader's LRU cache so
        # this embedder gets its own model instance — TODO confirm intent.
        model = load_pretrained_transformer.__wrapped__(
            model,
            auto_class=auto_class,
            submodule=subcmodule,
            **kwargs,
        )

    super().__init__()

    self._model = model
    self._scalar_mix: ScalarMix | None = None
    self._eval_mode = eval_mode
    # Cache config-derived sizes so getters don't depend on the wrapper state.
    self._output_dim = model.config.hidden_size
    self._vocab_size = model.config.vocab_size

    if gradient_checkpointing is not None:
        self._model.config.update({"gradient_checkpointing": gradient_checkpointing})

    if self._eval_mode:
        self._model.eval()

    if freeze:
        for param in self._model.parameters():
            param.requires_grad = False

    if layer_to_use == "all":
        # Scalar mix learns a weighted combination of all transformer layers;
        # the model must emit hidden states for it to consume.
        self._scalar_mix = ScalarMix(self._model.config.num_hidden_layers)
        self._model.config.output_hidden_states = True
    elif layer_to_use == "embeddings":
        # Replace the full model with a thin wrapper around its embedding layer.
        self._model = PretrainedTransformerEmbedder._Embedding(self._model)

forward

forward(inputs)
Source code in src/formed/integrations/torch/modules/embedders.py
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
def forward(self, inputs: IIDSequenceBatch) -> EmbedderOutput:
    """Run the transformer and return contextual embeddings plus the attention mask."""
    input_ids = ensure_torch_tensor(inputs.ids)
    attention_mask = ensure_torch_tensor(inputs.mask)

    if isinstance(self._model, PretrainedTransformerEmbedder._Embedding):
        # Embedding-only mode: skip the transformer stack entirely.
        return EmbedderOutput(self._model(input_ids), attention_mask)

    outputs = self._model(input_ids=input_ids, attention_mask=attention_mask)
    if self._scalar_mix is None:
        embeddings = outputs.last_hidden_state
    else:
        # hidden_states[0] is the embedding layer's output; only the
        # transformer layers participate in the scalar mix.
        embeddings = self._scalar_mix(outputs.hidden_states[1:])
    return EmbedderOutput(embeddings, attention_mask)

get_output_dim

get_output_dim()
Source code in src/formed/integrations/torch/modules/embedders.py
387
388
def get_output_dim(self) -> int:
    # Hidden size of the wrapped transformer, captured at construction time.
    return self._output_dim

get_vocab_size

get_vocab_size()
Source code in src/formed/integrations/torch/modules/embedders.py
390
391
def get_vocab_size(self) -> int:
    # Vocabulary size of the wrapped transformer, captured at construction time.
    return self._vocab_size

train

train(mode=True)
Source code in src/formed/integrations/torch/modules/embedders.py
393
394
395
396
397
398
399
400
def train(self, mode: bool = True) -> Self:
    """Set training mode, but pin the wrapped model to eval when `eval_mode` was requested.

    Args:
        mode: Whether to enable training mode for child modules.

    Returns:
        This module.
    """
    self.training = mode
    for name, module in self.named_children():
        if self._eval_mode and name == "_model":
            # Keep the pretrained model in eval mode regardless of `mode`.
            module.eval()
        else:
            module.train(mode)
    return self

AnalyzedTextEmbedder

AnalyzedTextEmbedder(
    surface=None,
    postag=None,
    character=None,
    token_vector=None,
)

Bases: BaseEmbedder['IAnalyzedTextBatch']

Embedder for analyzed text with multiple linguistic features.

This embedder combines embeddings from multiple linguistic representations (surface forms, part-of-speech tags, character sequences) by concatenating them along the feature dimension.

PARAMETER DESCRIPTION
surface

Optional embedder for surface form tokens.

TYPE: Optional[BaseEmbedder[IIDSequenceBatch]] DEFAULT: None

postag

Optional embedder for part-of-speech tags.

TYPE: Optional[BaseEmbedder[IIDSequenceBatch]] DEFAULT: None

character

Optional embedder for character sequences.

TYPE: Optional[BaseEmbedder[IIDSequenceBatch]] DEFAULT: None

RAISES DESCRIPTION
ValueError

If all embedders are None (at least one is required).

Examples:

>>> from formed.integrations.torch.modules import (
...     AnalyzedTextEmbedder,
...     TokenEmbedder
... )
>>>
>>> embedder = AnalyzedTextEmbedder(
...     surface=TokenEmbedder(vocab_size=10000, embedding_dim=128),
...     postag=TokenEmbedder(vocab_size=50, embedding_dim=32),
...     character=TokenEmbedder(vocab_size=256, embedding_dim=32)
... )
>>>
>>> # Output dimension is sum of all embedding dimensions (128 + 32 + 32 = 192)
>>> assert embedder.get_output_dim() == 192
Note

All provided embedders share the same mask, which is taken from the last non-None embedder processed.

Source code in src/formed/integrations/torch/modules/embedders.py
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
def __init__(
    self,
    surface: Optional["BaseEmbedder[IIDSequenceBatch]"] = None,
    postag: Optional["BaseEmbedder[IIDSequenceBatch]"] = None,
    character: Optional["BaseEmbedder[IIDSequenceBatch]"] = None,
    token_vector: Optional["BaseEmbedder[IVariableTensorBatch]"] = None,
) -> None:
    """Store the configured per-feature embedders.

    Args:
        surface: Optional embedder for surface-form tokens.
        postag: Optional embedder for part-of-speech tags.
        character: Optional embedder for character sequences.
        token_vector: Optional embedder for precomputed token vectors.

    Raises:
        ValueError: If surface, postag, and character are all None.
    """
    super().__init__()
    # The shared attention mask is produced by the surface/postag/character
    # embedders, so at least one of those three must be present.
    if surface is None and postag is None and character is None:
        raise ValueError("At least one embedder must be provided for AnalyzedTextEmbedder.")

    self._surface = surface
    self._postag = postag
    self._character = character
    self._token_vector = token_vector

forward

forward(inputs)
Source code in src/formed/integrations/torch/modules/embedders.py
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
def forward(self, inputs: "IAnalyzedTextBatch") -> EmbedderOutput:
    """Concatenate the embeddings of all configured features along the feature axis."""
    parts: list[torch.Tensor] = []
    mask: Optional[torch.Tensor] = None

    feature_pairs = (
        (self._surface, inputs.surfaces),
        (self._postag, inputs.postags),
        (self._character, inputs.characters),
    )
    for embedder, batch in feature_pairs:
        if embedder is None or batch is None:
            continue
        out = embedder(batch)
        parts.append(out.embeddings)
        # The last embedded feature's mask wins; all features are expected
        # to share the same sequence layout.
        mask = out.mask

    if self._token_vector is not None and inputs.token_vectors is not None:
        parts.append(self._token_vector(inputs.token_vectors).embeddings)

    if not parts:
        raise ValueError("No embeddings were computed in AnalyzedTextEmbedder.")
    assert mask is not None

    return EmbedderOutput(embeddings=torch.cat(parts, dim=-1), mask=mask)

get_output_dim

get_output_dim()
Source code in src/formed/integrations/torch/modules/embedders.py
480
481
482
483
484
485
def get_output_dim(self) -> int:
    """Return the total concatenated embedding dimension.

    `forward` concatenates the outputs of every configured embedder —
    including the token-vector embedder — so the reported dimension must sum
    over all of them. The previous implementation omitted `self._token_vector`,
    under-reporting the output size whenever a token-vector embedder was set.

    Returns:
        Sum of the output dimensions of all configured embedders.
    """
    return sum(
        embedder.get_output_dim()
        for embedder in (self._surface, self._postag, self._character, self._token_vector)
        if embedder is not None
    )

formed.integrations.torch.modules.encoders

Sequence encoding modules for PyTorch models.

This module provides encoders that process sequential data, including RNN-based encoders, positional encoders, and Transformer encoders.

Key Components
  • BaseSequenceEncoder: Abstract base for sequence encoders
  • LSTMSequenceEncoder: LSTM-specific encoder
  • GRUSequenceEncoder: GRU-specific encoder
  • BasePositionalEncoder: Abstract base for positional encoders
  • SinusoidalPositionalEncoder: Sinusoidal positional encoding
  • RotaryPositionalEncoder: Rotary positional encoding (RoPE)
  • LearnablePositionalEncoder: Learnable positional embeddings
  • TransformerEncoder: Transformer-based encoder with configurable masking
Features
  • Bidirectional RNN support
  • Stacked layers with dropout
  • Masked sequence processing
  • Various positional encoding strategies
  • Flexible attention masking

Examples:

>>> from formed.integrations.torch.modules import LSTMSequenceEncoder
>>>
>>> # Bidirectional LSTM encoder
>>> encoder = LSTMSequenceEncoder(
...     input_dim=128,
...     hidden_dim=256,
...     num_layers=2,
...     bidirectional=True,
...     dropout=0.1
... )

BaseSequenceEncoder

Bases: Module, Registrable, ABC

Abstract base class for sequence encoders.

Sequence encoders process sequential data and output encoded representations.

forward abstractmethod

forward(inputs, mask=None)

Encode input sequence.

PARAMETER DESCRIPTION
inputs

Input sequence of shape (batch_size, seq_len, input_dim).

TYPE: Tensor

mask

Optional mask of shape (batch_size, seq_len).

TYPE: Optional[Tensor] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Encoded sequence of shape (batch_size, seq_len, output_dim).

Source code in src/formed/integrations/torch/modules/encoders.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
@abc.abstractmethod
def forward(
    self,
    inputs: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Encode input sequence.

    Subclasses consume a batch-first sequence tensor and return one encoded
    vector per timestep.

    Args:
        inputs: Input sequence of shape `(batch_size, seq_len, input_dim)`.
        mask: Optional mask of shape `(batch_size, seq_len)`.

    Returns:
        Encoded sequence of shape `(batch_size, seq_len, output_dim)`.

    """
    raise NotImplementedError

get_input_dim abstractmethod

get_input_dim()

Get the expected input dimension.

Source code in src/formed/integrations/torch/modules/encoders.py
75
76
77
78
@abc.abstractmethod
def get_input_dim(self) -> int:
    """Get the expected input dimension (last axis of `forward` inputs)."""
    raise NotImplementedError

get_output_dim abstractmethod

get_output_dim()

Get the output dimension.

Source code in src/formed/integrations/torch/modules/encoders.py
80
81
82
83
@abc.abstractmethod
def get_output_dim(self) -> int:
    """Get the output dimension (last axis of `forward` outputs)."""
    raise NotImplementedError

LSTMSequenceEncoder

LSTMSequenceEncoder(
    input_dim,
    hidden_dim,
    num_layers=1,
    bidirectional=False,
    dropout=0.0,
    batch_first=True,
)

Bases: BaseSequenceEncoder

LSTM-based sequence encoder.

PARAMETER DESCRIPTION
input_dim

Input dimension.

TYPE: int

hidden_dim

Hidden state dimension.

TYPE: int

num_layers

Number of LSTM layers.

TYPE: int DEFAULT: 1

bidirectional

Whether to use bidirectional LSTM.

TYPE: bool DEFAULT: False

dropout

Dropout rate between layers.

TYPE: float DEFAULT: 0.0

batch_first

Whether input is batch-first (default: True).

TYPE: bool DEFAULT: True

Examples:

>>> encoder = LSTMSequenceEncoder(
...     input_dim=128,
...     hidden_dim=256,
...     num_layers=2,
...     bidirectional=True
... )
Source code in src/formed/integrations/torch/modules/encoders.py
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def __init__(
    self,
    input_dim: int,
    hidden_dim: int,
    num_layers: int = 1,
    bidirectional: bool = False,
    dropout: float = 0.0,
    batch_first: bool = True,
) -> None:
    """Construct the underlying `nn.LSTM` and record its dimensions.

    Args:
        input_dim: Input feature dimension.
        hidden_dim: Hidden state dimension per direction.
        num_layers: Number of stacked LSTM layers.
        bidirectional: If True, run forward and backward passes.
        dropout: Dropout between stacked layers.
        batch_first: If True, inputs are `(batch, seq, feature)`.
    """
    super().__init__()
    self._input_dim = input_dim
    self._hidden_dim = hidden_dim
    self._num_layers = num_layers
    self._bidirectional = bidirectional

    # Inter-layer dropout has no effect with a single layer; zero it out to
    # avoid the PyTorch warning.
    stacked_dropout = dropout if num_layers > 1 else 0.0
    self.lstm = nn.LSTM(
        input_size=input_dim,
        hidden_size=hidden_dim,
        num_layers=num_layers,
        bidirectional=bidirectional,
        dropout=stacked_dropout,
        batch_first=batch_first,
    )

lstm instance-attribute

lstm = LSTM(
    input_size=input_dim,
    hidden_size=hidden_dim,
    num_layers=num_layers,
    bidirectional=bidirectional,
    dropout=dropout if num_layers > 1 else 0.0,
    batch_first=batch_first,
)

forward

forward(inputs, mask=None)

Encode input sequence.

PARAMETER DESCRIPTION
inputs

Input of shape (batch_size, seq_len, input_dim).

TYPE: Tensor

mask

Optional mask of shape (batch_size, seq_len).

TYPE: Optional[Tensor] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Encoded sequence of shape (batch_size, seq_len, output_dim).

Source code in src/formed/integrations/torch/modules/encoders.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def forward(
    self,
    inputs: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Encode a batch-first sequence tensor with the LSTM.

    When a mask is given, sequences are packed so padded timesteps are
    skipped; outputs at padded positions come back as zeros.

    Args:
        inputs: Input of shape `(batch_size, seq_len, input_dim)`.
        mask: Optional mask of shape `(batch_size, seq_len)`.

    Returns:
        Encoded sequence of shape `(batch_size, seq_len, output_dim)`.

    """
    if mask is None:
        encoded, _ = self.lstm(inputs)
        return encoded

    # Pack so the LSTM skips padded timesteps; lengths must live on CPU.
    lengths = mask.sum(dim=1).cpu()
    packed_inputs = nn.utils.rnn.pack_padded_sequence(inputs, lengths, batch_first=True, enforce_sorted=False)
    packed_outputs, _ = self.lstm(packed_inputs)
    encoded, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True)
    return encoded

get_input_dim

get_input_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
165
166
def get_input_dim(self) -> int:
    """Return the feature dimension this encoder expects."""
    return self._input_dim

get_output_dim

get_output_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
168
169
def get_output_dim(self) -> int:
    """Return the encoded dimension (doubled when bidirectional)."""
    if self._bidirectional:
        return 2 * self._hidden_dim
    return self._hidden_dim

GRUSequenceEncoder

GRUSequenceEncoder(
    input_dim,
    hidden_dim,
    num_layers=1,
    bidirectional=False,
    dropout=0.0,
    batch_first=True,
)

Bases: BaseSequenceEncoder

GRU-based sequence encoder.

PARAMETER DESCRIPTION
input_dim

Input dimension.

TYPE: int

hidden_dim

Hidden state dimension.

TYPE: int

num_layers

Number of GRU layers.

TYPE: int DEFAULT: 1

bidirectional

Whether to use bidirectional GRU.

TYPE: bool DEFAULT: False

dropout

Dropout rate between layers.

TYPE: float DEFAULT: 0.0

batch_first

Whether input is batch-first (default: True).

TYPE: bool DEFAULT: True

Examples:

>>> encoder = GRUSequenceEncoder(
...     input_dim=128,
...     hidden_dim=256,
...     num_layers=2,
...     bidirectional=True
... )
Source code in src/formed/integrations/torch/modules/encoders.py
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
def __init__(
    self,
    input_dim: int,
    hidden_dim: int,
    num_layers: int = 1,
    bidirectional: bool = False,
    dropout: float = 0.0,
    batch_first: bool = True,
) -> None:
    """Build the underlying GRU and record its dimensions.

    Args:
        input_dim: Input dimension.
        hidden_dim: Hidden state dimension.
        num_layers: Number of GRU layers.
        bidirectional: Whether to use a bidirectional GRU.
        dropout: Dropout rate between stacked layers.
        batch_first: Whether input is batch-first.

    """
    super().__init__()

    # Inter-layer dropout only applies when there is more than one layer.
    effective_dropout = dropout if num_layers > 1 else 0.0

    self._input_dim = input_dim
    self._hidden_dim = hidden_dim
    self._num_layers = num_layers
    self._bidirectional = bidirectional
    self.gru = nn.GRU(
        input_size=input_dim,
        hidden_size=hidden_dim,
        num_layers=num_layers,
        bidirectional=bidirectional,
        dropout=effective_dropout,
        batch_first=batch_first,
    )

gru instance-attribute

gru = GRU(
    input_size=input_dim,
    hidden_size=hidden_dim,
    num_layers=num_layers,
    bidirectional=bidirectional,
    dropout=dropout if num_layers > 1 else 0.0,
    batch_first=batch_first,
)

forward

forward(inputs, mask=None)

Encode input sequence.

PARAMETER DESCRIPTION
inputs

Input of shape (batch_size, seq_len, input_dim).

TYPE: Tensor

mask

Optional mask of shape (batch_size, seq_len).

TYPE: Optional[Tensor] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Encoded sequence of shape (batch_size, seq_len, output_dim).

Source code in src/formed/integrations/torch/modules/encoders.py
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
def forward(
    self,
    inputs: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the GRU over a batch of sequences.

    Args:
        inputs: Input of shape `(batch_size, seq_len, input_dim)`.
        mask: Optional mask of shape `(batch_size, seq_len)`; True marks
            valid positions.

    Returns:
        Encoded sequence of shape `(batch_size, seq_len, output_dim)`.

    """
    if mask is None:
        encoded, _ = self.gru(inputs)
        return encoded

    # Pack the padded batch so the GRU skips padding positions.
    lengths = mask.sum(dim=1).cpu()
    packed = nn.utils.rnn.pack_padded_sequence(inputs, lengths, batch_first=True, enforce_sorted=False)
    packed_encoded, _ = self.gru(packed)
    encoded, _ = nn.utils.rnn.pad_packed_sequence(packed_encoded, batch_first=True)
    return encoded

get_input_dim

get_input_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
244
245
def get_input_dim(self) -> int:
    """Return the feature dimension consumed by the GRU."""
    return self._input_dim

get_output_dim

get_output_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
247
248
def get_output_dim(self) -> int:
    """Return the output dimension, accounting for bidirectionality."""
    directions = 2 if self._bidirectional else 1
    return self._hidden_dim * directions

ResidualSequenceEncoder

ResidualSequenceEncoder(encoder)

Bases: BaseSequenceEncoder

Residual wrapper for sequence encoders.

Adds the input to the encoder output (residual connection). Requires input and output dimensions to match.

PARAMETER DESCRIPTION
encoder

Base encoder to wrap. Must have matching input and output dimensions.

TYPE: BaseSequenceEncoder

Examples:

>>> from formed.integrations.torch.modules.encoders import (
...     ResidualSequenceEncoder,
...     LSTMSequenceEncoder
... )
>>>
>>> # Wrap LSTM with residual connection
>>> base_encoder = LSTMSequenceEncoder(input_dim=128, hidden_dim=128)
>>> encoder = ResidualSequenceEncoder(encoder=base_encoder)
Source code in src/formed/integrations/torch/modules/encoders.py
273
274
275
276
277
def __init__(self, encoder: BaseSequenceEncoder) -> None:
    """Wrap ``encoder`` with a residual (identity) connection.

    Args:
        encoder: Base encoder to wrap. Its input and output dimensions
            must match so that ``inputs + encoded`` is well-defined.

    Raises:
        ValueError: If the encoder's input and output dimensions differ.

    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if encoder.get_input_dim() != encoder.get_output_dim():
        raise ValueError(
            "ResidualSequenceEncoder requires matching input and output dimensions, "
            f"got input_dim={encoder.get_input_dim()} and output_dim={encoder.get_output_dim()}"
        )

    super().__init__()
    self._encoder = encoder

forward

forward(inputs, mask=None)
Source code in src/formed/integrations/torch/modules/encoders.py
279
280
281
282
283
284
285
def forward(
    self,
    inputs: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    encoded = self._encoder(inputs, mask=mask)
    return inputs + encoded

get_input_dim

get_input_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
287
288
def get_input_dim(self) -> int:
    """Delegate to the wrapped encoder's input dimension."""
    return self._encoder.get_input_dim()

get_output_dim

get_output_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
290
291
def get_output_dim(self) -> int:
    """Delegate to the wrapped encoder's output dimension."""
    return self._encoder.get_output_dim()

FeedForwardSequenceEncoder

FeedForwardSequenceEncoder(feedforward)

Bases: BaseSequenceEncoder

Position-wise feedforward sequence encoder.

Applies a feedforward network independently to each position in the sequence. The same transformation is applied at each position (no cross-position interaction).

PARAMETER DESCRIPTION
feedforward

Feedforward network to apply at each position.

TYPE: FeedForward

Examples:

>>> from formed.integrations.torch.modules.encoders import (
...     FeedForwardSequenceEncoder
... )
>>> from formed.integrations.torch.modules.feedforward import FeedForward
>>>
>>> # Apply feedforward to each position independently
>>> feedforward = FeedForward(input_dim=128, hidden_dims=[256, 128])
>>> encoder = FeedForwardSequenceEncoder(feedforward=feedforward)
Source code in src/formed/integrations/torch/modules/encoders.py
316
317
318
def __init__(self, feedforward: FeedForward) -> None:
    """Store the position-wise feedforward network to apply per position."""
    super().__init__()
    self._feedforward = feedforward

forward

forward(inputs, mask=None)
Source code in src/formed/integrations/torch/modules/encoders.py
320
321
322
323
324
325
326
327
328
329
330
def forward(
    self,
    inputs: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    del mask

    original_shape = inputs.shape
    flattened_inputs = inputs.view(-1, original_shape[-1])
    encoded = self._feedforward(flattened_inputs)
    return encoded.view(*original_shape[:-1], -1)

get_input_dim

get_input_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
332
333
def get_input_dim(self) -> int:
    """Delegate to the feedforward network's input dimension."""
    return self._feedforward.get_input_dim()

get_output_dim

get_output_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
335
336
def get_output_dim(self) -> int:
    """Delegate to the feedforward network's output dimension."""
    return self._feedforward.get_output_dim()

GatedCnnSequenceEncoder

GatedCnnSequenceEncoder(
    input_dim, layers, output_dim=None, dropout=0.0
)

Bases: BaseSequenceEncoder

Gated Convolutional Neural Network sequence encoder.

Uses stacked residual blocks with gated linear units (GLU) for efficient sequence modeling. Processes sequences in both forward and backward directions, then concatenates the results for bidirectional context capture.

Based on "Language Modeling with Gated Convolutional Networks" (Dauphin et al., 2017).

PARAMETER DESCRIPTION
input_dim

Input dimension.

TYPE: int

layers

List of layer configurations for each residual block. Each block is a list of Layer(kernel_size, output_dim, dilation).

TYPE: Sequence[Sequence[Layer]]

output_dim

Optional output dimension. If provided, applies linear projection. Default is input_dim * 2 (concatenation of forward + backward).

TYPE: Optional[int] DEFAULT: None

dropout

Dropout rate applied to the first convolution of each block.

TYPE: float DEFAULT: 0.0

Examples:

>>> # Simple gated CNN encoder
>>> encoder = GatedCnnSequenceEncoder(
...     input_dim=128,
...     layers=[
...         [GatedCnnSequenceEncoder.Layer(kernel_size=3, output_dim=128)],
...         [GatedCnnSequenceEncoder.Layer(kernel_size=3, output_dim=128)],
...     ]
... )
>>>
>>> # With dilated convolutions for larger receptive field
>>> encoder = GatedCnnSequenceEncoder(
...     input_dim=128,
...     layers=[
...         [GatedCnnSequenceEncoder.Layer(kernel_size=2, output_dim=128, dilation=1)],
...         [GatedCnnSequenceEncoder.Layer(kernel_size=2, output_dim=128, dilation=2)],
...         [GatedCnnSequenceEncoder.Layer(kernel_size=2, output_dim=128, dilation=4)],
...     ],
...     output_dim=256,
...     dropout=0.1
... )
Source code in src/formed/integrations/torch/modules/encoders.py
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
def __init__(
    self,
    input_dim: int,
    layers: Sequence[Sequence["GatedCnnSequenceEncoder.Layer"]],
    output_dim: Optional[int] = None,
    dropout: float = 0.0,
) -> None:
    """Build forward and backward residual blocks plus optional projection.

    Args:
        input_dim: Input dimension.
        layers: Layer configurations, one sequence per residual block.
        output_dim: Optional output dimension; enables a linear projection.
        dropout: Dropout rate applied in the first convolution of each block.

    """
    super().__init__()

    self._input_dim = input_dim
    # Without a projection the output is the forward+backward concatenation.
    self._output_dim = output_dim or input_dim * 2

    self._forward_residual_blocks = torch.nn.ModuleList()
    self._backward_residual_blocks = torch.nn.ModuleList()
    for block_layers in layers:
        self._forward_residual_blocks.append(
            GatedCnnSequenceEncoder.ResidualBlock(input_dim, block_layers, "forward", dropout=dropout)
        )
        self._backward_residual_blocks.append(
            GatedCnnSequenceEncoder.ResidualBlock(input_dim, block_layers, "backward", dropout=dropout)
        )

    self._projection: Optional[torch.nn.Linear] = None
    if output_dim:
        self._projection = torch.nn.Linear(input_dim * 2, output_dim)

Layer

Bases: NamedTuple

Configuration for a single convolutional layer.

ATTRIBUTE DESCRIPTION
kernel_size

Size of the convolution kernel.

TYPE: int

output_dim

Output dimension of the layer. Must match input_dim for residual connections to work.

TYPE: int

dilation

Dilation rate for the convolution. When dilation > 1, kernel_size must be 2.

TYPE: int

kernel_size instance-attribute
kernel_size
output_dim instance-attribute
output_dim
dilation class-attribute instance-attribute
dilation = 1

ResidualBlock

ResidualBlock(
    input_dim,
    layers,
    direction,
    do_weight_norm=True,
    dropout=0.0,
)

Bases: Module

Residual block with gated convolutions for sequence encoding.

Stacks multiple gated convolutional layers with residual connections. Supports causal masking via directional processing (forward/backward).

PARAMETER DESCRIPTION
input_dim

Input dimension. Must match output dimension of all layers for residual connection.

TYPE: int

layers

Sequence of Layer configurations defining the convolutional stack.

TYPE: Sequence[Layer]

direction

Direction of causal masking ("forward" or "backward").

TYPE: Literal['forward', 'backward']

do_weight_norm

Whether to apply weight normalization to convolutions.

TYPE: bool DEFAULT: True

dropout

Dropout rate applied to the first convolution.

TYPE: float DEFAULT: 0.0

Source code in src/formed/integrations/torch/modules/encoders.py
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
def __init__(
    self,
    input_dim: int,
    layers: Sequence["GatedCnnSequenceEncoder.Layer"],
    direction: Literal["forward", "backward"],
    do_weight_norm: bool = True,
    dropout: float = 0.0,
) -> None:
    """Build the stack of gated convolutions for one residual block.

    Args:
        input_dim: Input dimension; must equal the final layer's output
            dimension for the residual connection to be valid.
        layers: Layer configurations defining the convolutional stack.
        direction: Direction of causal masking ("forward" or "backward").
        do_weight_norm: Whether to apply weight normalization.
        dropout: Dropout rate applied to the first convolution.

    Raises:
        ValueError: If ``direction`` is invalid, a dilated layer has a
            kernel size other than 2, or the final output dimension does
            not match ``input_dim``.

    """
    super().__init__()

    # Validate configuration up front so we fail before building layers.
    # Explicit raises (not `assert`) so checks survive `python -O`.
    if direction not in ("forward", "backward"):
        raise ValueError(f"invalid direction: {direction}")
    self._direction = direction
    self.dropout = dropout

    self._convolutions = torch.nn.ModuleList()
    last_dim = input_dim
    for k, layer in enumerate(layers):
        if layer.dilation == 1:
            conv = torch.nn.Conv1d(
                in_channels=last_dim,
                # Doubled channels: one half is the GLU gate.
                out_channels=layer.output_dim * 2,
                kernel_size=layer.kernel_size,
                stride=1,
                # Causal padding; use the named field (was `layer[0] - 1`).
                padding=layer.kernel_size - 1,
                bias=True,
            )
        else:
            if layer.kernel_size != 2:
                raise ValueError("only support kernel = 2 for now")
            conv = torch.nn.Conv1d(
                in_channels=last_dim,
                out_channels=layer.output_dim * 2,
                kernel_size=layer.kernel_size,
                stride=1,
                padding=layer.dilation,
                dilation=layer.dilation,
                bias=True,
            )

        # Dropout-compensated initialization (as in Dauphin et al., 2017);
        # only the first convolution sees dropout.
        conv_dropout = dropout if k == 0 else 0.0
        std = math.sqrt((4 * (1.0 - conv_dropout)) / (layer.kernel_size * last_dim))

        conv.weight.data.normal_(0, std=std)
        if conv.bias is not None:
            conv.bias.data.zero_()

        if do_weight_norm:
            conv = torch.nn.utils.weight_norm(conv, name="weight", dim=0)

        self._convolutions.append(conv)
        last_dim = layer.output_dim

    if last_dim != input_dim:
        raise ValueError(
            f"final layer output_dim ({last_dim}) must equal input_dim ({input_dim}) "
            "for the residual connection"
        )
dropout instance-attribute
dropout = dropout
forward
forward(inputs)

Apply gated convolutions with residual connection.

PARAMETER DESCRIPTION
inputs

Input of shape (batch_size, input_dim, seq_len).

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Encoded sequence with residual connection of shape (batch_size, output_dim, seq_len).

Source code in src/formed/integrations/torch/modules/encoders.py
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
    """Apply the gated convolution stack and add the residual input.

    Args:
        inputs: Input of shape `(batch_size, input_dim, seq_len)`.

    Returns:
        Encoded sequence with residual connection of shape
        `(batch_size, output_dim, seq_len)`.

    """
    seq_len = inputs.size(2)
    hidden = inputs
    for index, convolution in enumerate(self._convolutions):
        # Dropout is only applied before the first convolution.
        if index == 0 and self.dropout > 0:
            hidden = torch.nn.functional.dropout(hidden, self.dropout, self.training)

        conv_out = convolution(hidden)

        # Trim the causal padding so the temporal length is preserved:
        # keep the leading positions for "forward", trailing for "backward".
        excess = conv_out.size(2) - seq_len
        if excess > 0:
            start = 0 if self._direction == "forward" else excess
            conv_out = conv_out.narrow(2, start, seq_len)

        hidden = torch.nn.functional.glu(conv_out, dim=1)

    # Scale the residual sum to keep the variance roughly constant.
    return (hidden + inputs) * math.sqrt(0.5)

get_input_dim

get_input_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
525
526
def get_input_dim(self) -> int:
    """Return the input dimension of the gated CNN."""
    return self._input_dim

get_output_dim

get_output_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
528
529
def get_output_dim(self) -> int:
    """Return the output dimension of the gated CNN."""
    return self._output_dim

forward

forward(inputs, mask=None)

Encode input sequence using gated CNN.

PARAMETER DESCRIPTION
inputs

Input of shape (batch_size, seq_len, input_dim).

TYPE: Tensor

mask

Optional mask of shape (batch_size, seq_len). True indicates valid positions, False indicates padding.

TYPE: Tensor | None DEFAULT: None

RETURNS DESCRIPTION
Tensor

Encoded sequence of shape (batch_size, seq_len, output_dim).

Source code in src/formed/integrations/torch/modules/encoders.py
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
def forward(self, inputs: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
    """Encode input sequence using gated CNN.

    Args:
        inputs: Input of shape `(batch_size, seq_len, input_dim)`.
        mask: Optional mask of shape `(batch_size, seq_len)`.
             True indicates valid positions, False indicates padding.

    Returns:
        Encoded sequence of shape `(batch_size, seq_len, output_dim)`.

    """
    if mask is None:
        mask = torch.ones(*inputs.size()[:-1], dtype=torch.bool, device=inputs.device)
    else:
        # Ensure mask is boolean
        mask = mask.bool()

    transposed_embeddings = torch.transpose(inputs, 1, 2)
    mask_for_fill = ~mask.unsqueeze(1)

    outputs: list[torch.Tensor] = []
    for blocks in (self._forward_residual_blocks, self._backward_residual_blocks):
        out = transposed_embeddings
        for block in blocks:
            out = block(out.masked_fill(mask_for_fill, 0.0))
        outputs.append(out)

    output = torch.cat(outputs, dim=1).transpose(1, 2)
    if self._projection:
        output = self._projection(output)
    return output

StackedSequenceEncoder

StackedSequenceEncoder(encoders)

Bases: BaseSequenceEncoder

Stacks multiple sequence encoders sequentially.

Applies encoders in order, passing the output of each as input to the next. The output dimension of each encoder must match the input dimension of the next.

PARAMETER DESCRIPTION
encoders

List of encoders to apply in sequence. Each encoder's output dimension must match the next encoder's input dimension.

TYPE: list[BaseSequenceEncoder]

Examples:

>>> from formed.integrations.torch.modules.encoders import (
...     StackedSequenceEncoder,
...     LSTMSequenceEncoder,
...     GRUSequenceEncoder,
...     ResidualSequenceEncoder
... )
>>>
>>> # Stack LSTM and GRU
>>> encoders = [
...     LSTMSequenceEncoder(input_dim=128, hidden_dim=128),
...     GRUSequenceEncoder(input_dim=128, hidden_dim=64),
... ]
>>> encoder = StackedSequenceEncoder(encoders=encoders)
>>>
>>> # More complex: LSTM -> Residual LSTM -> GRU
>>> base_lstm = LSTMSequenceEncoder(input_dim=128, hidden_dim=128)
>>> residual_lstm = ResidualSequenceEncoder(encoder=base_lstm)
>>> gru = GRUSequenceEncoder(input_dim=128, hidden_dim=128)
>>> encoder = StackedSequenceEncoder(encoders=[base_lstm, residual_lstm, gru])
Source code in src/formed/integrations/torch/modules/encoders.py
599
600
601
602
603
def __init__(self, encoders: list[BaseSequenceEncoder]) -> None:
    """Stack ``encoders`` so each feeds its output into the next.

    Args:
        encoders: Encoders applied in order; each encoder's output
            dimension must match the next encoder's input dimension.

    Raises:
        ValueError: If ``encoders`` is empty.

    """
    # Fail with a clear message instead of an IndexError on `encoders[0]`.
    if not encoders:
        raise ValueError("StackedSequenceEncoder requires at least one encoder")

    super().__init__()
    self._encoders = torch.nn.ModuleList(encoders)
    self._input_dim = encoders[0].get_input_dim()
    self._output_dim = encoders[-1].get_output_dim()

forward

forward(inputs, mask=None)
Source code in src/formed/integrations/torch/modules/encoders.py
605
606
607
608
609
610
611
612
613
def forward(
    self,
    inputs: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    x = inputs
    for encoder in self._encoders:
        x = encoder(x, mask=mask)
    return x

get_input_dim

get_input_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
615
616
def get_input_dim(self) -> int:
    """Return the first stacked encoder's input dimension."""
    return self._input_dim

get_output_dim

get_output_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
618
619
def get_output_dim(self) -> int:
    """Return the last stacked encoder's output dimension."""
    return self._output_dim

ConcatSequenceEncoder

ConcatSequenceEncoder(encoders)

Bases: BaseSequenceEncoder

Concatenates outputs from multiple sequence encoders.

Applies multiple encoders in parallel to the same input and concatenates their outputs along the feature dimension. All encoders receive the same input tensor.

PARAMETER DESCRIPTION
encoders

List of encoders to apply in parallel. All encoders must have the same input dimension.

TYPE: list[BaseSequenceEncoder]

Examples:

>>> from formed.integrations.torch.modules.encoders import (
...     ConcatSequenceEncoder,
...     LSTMSequenceEncoder,
...     GRUSequenceEncoder
... )
>>>
>>> # Concatenate LSTM and GRU outputs
>>> encoders = [
...     LSTMSequenceEncoder(input_dim=128, hidden_dim=64),
...     GRUSequenceEncoder(input_dim=128, hidden_dim=64),
... ]
>>> encoder = ConcatSequenceEncoder(encoders=encoders)
Source code in src/formed/integrations/torch/modules/encoders.py
649
650
651
652
653
def __init__(self, encoders: list[BaseSequenceEncoder]) -> None:
    """Combine ``encoders`` applied in parallel to the same input.

    Args:
        encoders: Encoders applied in parallel; all must share the same
            input dimension since they receive the same input tensor.

    Raises:
        ValueError: If ``encoders`` is empty or their input dimensions differ.

    """
    if not encoders:
        raise ValueError("ConcatSequenceEncoder requires at least one encoder")
    input_dims = {encoder.get_input_dim() for encoder in encoders}
    if len(input_dims) != 1:
        raise ValueError(f"All encoders must share the same input dimension, got {sorted(input_dims)}")

    super().__init__()
    self._encoders = torch.nn.ModuleList(encoders)
    # Every encoder consumes the SAME input tensor, so the expected input
    # dimension is the shared per-encoder dimension — not their sum.
    self._input_dim = encoders[0].get_input_dim()
    # Outputs are concatenated along the feature axis, so those DO sum.
    self._output_dim = sum(encoder.get_output_dim() for encoder in encoders)

get_input_dim

get_input_dim()

Get the expected input dimension.

RETURNS DESCRIPTION
int

Sum of input dimensions across all encoders.

Source code in src/formed/integrations/torch/modules/encoders.py
655
656
657
658
659
660
661
662
def get_input_dim(self) -> int:
    """Get the expected input dimension.

    Returns:
        The input dimension recorded for this encoder.

    """
    return self._input_dim

get_output_dim

get_output_dim()

Get the output dimension.

RETURNS DESCRIPTION
int

Sum of output dimensions across all encoders.

Source code in src/formed/integrations/torch/modules/encoders.py
664
665
666
667
668
669
670
671
def get_output_dim(self) -> int:
    """Get the output dimension.

    Returns:
        Sum of output dimensions across all encoders.

    """
    return self._output_dim

forward

forward(inputs, mask=None)

Encode input sequence by concatenating outputs from all encoders.

PARAMETER DESCRIPTION
inputs

Input of shape (batch_size, seq_len, input_dim).

TYPE: Tensor

mask

Optional mask of shape (batch_size, seq_len).

TYPE: Tensor | None DEFAULT: None

RETURNS DESCRIPTION
Tensor

Concatenated encoded sequence of shape (batch_size, seq_len, output_dim).

Source code in src/formed/integrations/torch/modules/encoders.py
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
def forward(self, inputs: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
    """Encode input sequence by concatenating outputs from all encoders.

    Args:
        inputs: Input of shape `(batch_size, seq_len, input_dim)`.
        mask: Optional mask of shape `(batch_size, seq_len)`.

    Returns:
        Concatenated encoded sequence of shape `(batch_size, seq_len, output_dim)`.

    """
    outputs = []
    for encoder in self._encoders:
        outputs.append(encoder(inputs, mask=mask))
    return torch.cat(outputs, dim=-1)

WindowConcatSequenceEncoder

WindowConcatSequenceEncoder(
    input_dim, window_size, output_dim=None
)

Bases: BaseSequenceEncoder

Concatenates context window features for each position in the sequence.

For each position, concatenates the embeddings from surrounding positions within a specified window. This creates richer positional representations by explicitly including local context.

PARAMETER DESCRIPTION
input_dim

Input dimension.

TYPE: int

window_size

Size of context window on each side. If int, uses symmetric window. If tuple (left, right), uses asymmetric window.

TYPE: int | tuple[int, int]

output_dim

Optional output dimension. If provided, applies linear projection to the concatenated features. Otherwise, output dimension is (left_window + 1 + right_window) * input_dim.

TYPE: int | None DEFAULT: None

Examples:

>>> # Symmetric 2-position window on each side
>>> encoder = WindowConcatSequenceEncoder(
...     input_dim=128,
...     window_size=2
... )
>>>
>>> # Asymmetric window with projection
>>> encoder = WindowConcatSequenceEncoder(
...     input_dim=128,
...     window_size=(1, 2),
...     output_dim=256
... )
Source code in src/formed/integrations/torch/modules/encoders.py
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
def __init__(
    self,
    input_dim: int,
    window_size: int | tuple[int, int],
    output_dim: int | None = None,
) -> None:
    """Configure the context window and optional output projection.

    Args:
        input_dim: Input dimension.
        window_size: Context window size; an int gives a symmetric window,
            a (left, right) tuple an asymmetric one.
        output_dim: Optional output dimension for a linear projection.

    Raises:
        ValueError: If any window size is negative.

    """
    super().__init__()

    # Normalize an int window to a symmetric (left, right) pair.
    if isinstance(window_size, int):
        window_size = (window_size, window_size)
    if min(window_size) < 0:
        raise ValueError("Window size must be greater than or equal to zero.")

    self._input_dim = input_dim
    self._window_size = window_size

    concat_dim = (sum(window_size) + 1) * input_dim
    self._projection: Optional[torch.nn.Linear] = None
    if output_dim is not None:
        self._projection = torch.nn.Linear(concat_dim, output_dim)

get_input_dim

get_input_dim()

Get the expected input dimension.

RETURNS DESCRIPTION
int

Input dimension of the embeddings.

Source code in src/formed/integrations/torch/modules/encoders.py
742
743
744
745
746
747
748
749
def get_input_dim(self) -> int:
    """Get the expected input dimension.

    Returns:
        Input dimension of the embeddings.

    """
    return self._input_dim

get_output_dim

get_output_dim()

Get the output dimension.

RETURNS DESCRIPTION
int

Output dimension after window concatenation and optional projection.

Source code in src/formed/integrations/torch/modules/encoders.py
751
752
753
754
755
756
757
758
759
760
def get_output_dim(self) -> int:
    """Get the output dimension.

    Returns:
        Output dimension after window concatenation and optional projection.

    """
    if self._projection is not None:
        return self._projection.out_features
    left, right = self._window_size
    return (left + right + 1) * self._input_dim

forward

forward(inputs, mask=None)

Encode input sequence by concatenating context windows.

PARAMETER DESCRIPTION
inputs

Input of shape (batch_size, seq_len, input_dim).

TYPE: Tensor

mask

Optional mask of shape (batch_size, seq_len). True indicates valid positions, False indicates padding.

TYPE: Tensor | None DEFAULT: None

RETURNS DESCRIPTION
Tensor

Window-concatenated sequence of shape (batch_size, seq_len, output_dim).

Source code in src/formed/integrations/torch/modules/encoders.py
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
def forward(self, inputs: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
    """Encode input sequence by concatenating context windows.

    Args:
        inputs: Input of shape (batch_size, seq_len, input_dim).
        mask: Optional mask of shape (batch_size, seq_len).
             True indicates valid positions, False indicates padding.

    Returns:
        Window-concatenated sequence of shape (batch_size, seq_len, output_dim).

    """
    batch_size, max_length, embedding_dim = inputs.size()

    if mask is None:
        mask = torch.ones((batch_size, max_length), dtype=torch.bool, device=inputs.device)

    inputs = inputs * mask.float().unsqueeze(2)

    output = inputs
    lws, rws = self._window_size
    if lws > 0:
        pad = inputs.new_zeros((batch_size, lws, embedding_dim))
        x = torch.cat([pad, inputs], dim=1)
        x = torch.cat([x[:, offset : offset + max_length] for offset in range(lws)], dim=2)
        output = torch.cat([output, x], dim=2)
    if rws > 0:
        pad = inputs.new_zeros((batch_size, rws, embedding_dim))
        x = torch.cat([inputs, pad], dim=1)
        x = torch.cat([x[:, offset : offset + max_length] for offset in range(1, rws + 1)], dim=2)
        output = torch.cat([output, x], dim=2)

    if self._projection is not None:
        output = self._projection(output)

    return output * mask.float().unsqueeze(2)

BasePositionalEncoder

Bases: Module, Registrable, ABC

Abstract base class for positional encoders.

Positional encoders add positional information to sequential data.

forward abstractmethod

forward(inputs, mask=None)

Add positional encoding to input sequence.

PARAMETER DESCRIPTION
inputs

Input sequence of shape (batch_size, seq_len, input_dim).

TYPE: Tensor

mask

Optional mask of shape (batch_size, seq_len).

TYPE: Optional[Tensor] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Position-encoded sequence of shape (batch_size, seq_len, output_dim).

Source code in src/formed/integrations/torch/modules/encoders.py
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
@abc.abstractmethod
def forward(
    self,
    inputs: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Add positional information to an input sequence.

    Args:
        inputs: Input sequence of shape (batch_size, seq_len, input_dim).
        mask: Optional mask of shape (batch_size, seq_len).

    Returns:
        Position-encoded sequence of shape (batch_size, seq_len, output_dim).

    """
    raise NotImplementedError

get_input_dim abstractmethod

get_input_dim()

Get the expected input dimension.

Source code in src/formed/integrations/torch/modules/encoders.py
825
826
827
828
@abc.abstractmethod
def get_input_dim(self) -> int:
    """Return the input dimension this encoder expects."""
    raise NotImplementedError

get_output_dim abstractmethod

get_output_dim()

Get the output dimension.

Source code in src/formed/integrations/torch/modules/encoders.py
830
831
832
833
@abc.abstractmethod
def get_output_dim(self) -> int:
    """Return the output dimension this encoder produces."""
    raise NotImplementedError

SinusoidalPositionalEncoder

SinusoidalPositionalEncoder(
    input_dim, max_len=5000, dropout=0.0
)

Bases: BasePositionalEncoder

Sinusoidal positional encoding.

Uses sine and cosine functions of different frequencies to encode position information, as introduced in "Attention Is All You Need".

PARAMETER DESCRIPTION
input_dim

Dimension of the embeddings.

TYPE: int

max_len

Maximum sequence length to pre-compute.

TYPE: int DEFAULT: 5000

dropout

Dropout rate to apply after adding positional encoding.

TYPE: float DEFAULT: 0.0

Examples:

>>> encoder = SinusoidalPositionalEncoder(
...     input_dim=512,
...     max_len=5000,
...     dropout=0.1
... )
Source code in src/formed/integrations/torch/modules/encoders.py
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
def __init__(
    self,
    input_dim: int,
    max_len: int = 5000,
    dropout: float = 0.0,
) -> None:
    """Pre-compute the sinusoidal position table.

    Args:
        input_dim: Dimension of the embeddings.
        max_len: Maximum sequence length to pre-compute.
        dropout: Dropout rate applied after adding the encoding.

    """
    super().__init__()
    self._input_dim = input_dim
    self._max_len = max_len

    # Geometric progression of frequencies, as in "Attention Is All You Need".
    positions = torch.arange(max_len).unsqueeze(1)
    frequencies = torch.exp(torch.arange(0, input_dim, 2) * (-torch.log(torch.tensor(10000.0)) / input_dim))

    encoding = torch.zeros(1, max_len, input_dim)
    encoding[0, :, 0::2] = torch.sin(positions * frequencies)
    encoding[0, :, 1::2] = torch.cos(positions * frequencies)

    # Buffer: moves with the module but is not a learnable parameter.
    self.register_buffer("pe", encoding)
    self.dropout = nn.Dropout(p=dropout)

pe instance-attribute

pe

dropout instance-attribute

dropout = Dropout(p=dropout)

forward

forward(inputs, mask=None)

Add sinusoidal positional encoding to inputs.

PARAMETER DESCRIPTION
inputs

Input of shape (batch_size, seq_len, input_dim).

TYPE: Tensor

mask

Optional mask of shape (batch_size, seq_len).

TYPE: Optional[Tensor] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Position-encoded sequence of shape (batch_size, seq_len, input_dim).

Source code in src/formed/integrations/torch/modules/encoders.py
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
def forward(
    self,
    inputs: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Add the pre-computed sinusoidal encoding to ``inputs``.

    Args:
        inputs: Input of shape (batch_size, seq_len, input_dim).
        mask: Optional mask of shape (batch_size, seq_len).

    Returns:
        Position-encoded sequence of shape (batch_size, seq_len, input_dim).

    Raises:
        ValueError: If the sequence is longer than the pre-computed table.

    """
    seq_len = inputs.size(1)
    if seq_len > self._max_len:
        raise ValueError(f"Sequence length {seq_len} exceeds maximum length {self._max_len}")

    encoded = inputs + self.pe[:, :seq_len, :]
    return self.dropout(encoded)

get_input_dim

get_input_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
909
910
def get_input_dim(self) -> int:
    """Return the embedding dimension this encoder expects."""
    return self._input_dim

get_output_dim

get_output_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
912
913
def get_output_dim(self) -> int:
    """Return the output feature dimension (same as the input dimension)."""
    return self._input_dim

RotaryPositionalEncoder

RotaryPositionalEncoder(
    input_dim, max_len=2048, base=10000.0
)

Bases: BasePositionalEncoder

Rotary positional encoding (RoPE).

Applies rotary position embeddings by rotating pairs of dimensions in the feature space, as introduced in "RoFormer: Enhanced Transformer with Rotary Position Embedding".

PARAMETER DESCRIPTION
input_dim

Dimension of the embeddings (must be even).

TYPE: int

max_len

Maximum sequence length to pre-compute.

TYPE: int DEFAULT: 2048

base

Base for the geometric progression (default: 10000).

TYPE: float DEFAULT: 10000.0

Examples:

>>> encoder = RotaryPositionalEncoder(
...     input_dim=512,
...     max_len=2048
... )
Source code in src/formed/integrations/torch/modules/encoders.py
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
def __init__(
    self,
    input_dim: int,
    max_len: int = 2048,
    base: float = 10000.0,
) -> None:
    """Precompute the rotary-embedding frequency tables.

    Args:
        input_dim: Embedding dimension; must be even so dimensions pair up.
        max_len: Number of positions to precompute cos/sin tables for.
        base: Base of the geometric frequency progression.

    Raises:
        ValueError: If ``input_dim`` is odd.

    """
    super().__init__()
    if input_dim % 2 != 0:
        raise ValueError(f"input_dim must be even, got {input_dim}")

    self._input_dim = input_dim
    self._max_len = max_len

    # One inverse frequency per dimension pair, forming a geometric progression.
    inv_freq = 1.0 / (base ** (torch.arange(0, input_dim, 2).float() / input_dim))
    self.register_buffer("inv_freq", inv_freq)

    # Cache cos/sin for every position up to max_len; the angle table is
    # concatenated with itself so it spans the full feature dimension.
    positions = torch.arange(max_len).float()
    angles = torch.outer(positions, inv_freq)
    table = torch.cat((angles, angles), dim=-1)

    self.register_buffer("cos_cached", table.cos()[None, :, :])
    self.register_buffer("sin_cached", table.sin()[None, :, :])

inv_freq instance-attribute

inv_freq

cos_cached instance-attribute

cos_cached

sin_cached instance-attribute

sin_cached

forward

forward(inputs, mask=None)

Apply rotary positional encoding to inputs.

PARAMETER DESCRIPTION
inputs

Input of shape (batch_size, seq_len, input_dim).

TYPE: Tensor

mask

Optional mask of shape (batch_size, seq_len).

TYPE: Optional[Tensor] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Position-encoded sequence of shape (batch_size, seq_len, input_dim).

Source code in src/formed/integrations/torch/modules/encoders.py
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
def forward(
    self,
    inputs: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Rotate feature pairs of ``inputs`` according to their positions.

    Args:
        inputs: Tensor of shape (batch_size, seq_len, input_dim).
        mask: Unused here; accepted for interface compatibility.

    Returns:
        Position-encoded tensor of shape (batch_size, seq_len, input_dim).

    Raises:
        ValueError: If the sequence is longer than the precomputed tables.

    """
    length = inputs.size(1)
    if length > self._max_len:
        raise ValueError(f"Sequence length {length} exceeds maximum length {self._max_len}")

    # Slice the cached tables down to the current sequence length.
    cos_table = self.cos_cached[:, :length, :]
    sin_table = self.sin_cached[:, :length, :]
    return (inputs * cos_table) + (self._rotate_half(inputs) * sin_table)

get_input_dim

get_input_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
995
996
def get_input_dim(self) -> int:
    """Return the expected input feature dimension."""
    return self._input_dim

get_output_dim

get_output_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
998
999
def get_output_dim(self) -> int:
    """Return the output feature dimension (same as the input dimension)."""
    return self._input_dim

LearnablePositionalEncoder

LearnablePositionalEncoder(
    input_dim, max_len=1024, dropout=0.0
)

Bases: BasePositionalEncoder

Learnable positional embeddings.

Uses a learnable embedding table to encode position information, similar to token embeddings.

PARAMETER DESCRIPTION
input_dim

Dimension of the embeddings.

TYPE: int

max_len

Maximum sequence length (vocabulary size for positions).

TYPE: int DEFAULT: 1024

dropout

Dropout rate to apply after adding positional encoding.

TYPE: float DEFAULT: 0.0

Examples:

>>> encoder = LearnablePositionalEncoder(
...     input_dim=512,
...     max_len=1024,
...     dropout=0.1
... )
Source code in src/formed/integrations/torch/modules/encoders.py
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
def __init__(
    self,
    input_dim: int,
    max_len: int = 1024,
    dropout: float = 0.0,
) -> None:
    """Set up the learnable position-embedding table.

    Args:
        input_dim: Dimension of the embeddings.
        max_len: Maximum number of distinct positions (table size).
        dropout: Dropout rate applied after adding position embeddings.

    """
    super().__init__()
    self._input_dim = input_dim
    self._max_len = max_len

    # One learned vector per position, looked up like token embeddings.
    self.position_embeddings = nn.Embedding(max_len, input_dim)
    self.dropout = nn.Dropout(p=dropout)

position_embeddings instance-attribute

position_embeddings = Embedding(max_len, input_dim)

dropout instance-attribute

dropout = Dropout(p=dropout)

forward

forward(inputs, mask=None)

Add learnable positional encoding to inputs.

PARAMETER DESCRIPTION
inputs

Input of shape (batch_size, seq_len, input_dim).

TYPE: Tensor

mask

Optional mask of shape (batch_size, seq_len).

TYPE: Optional[Tensor] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Position-encoded sequence of shape (batch_size, seq_len, input_dim).

Source code in src/formed/integrations/torch/modules/encoders.py
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
def forward(
    self,
    inputs: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Add learned position embeddings to ``inputs``.

    Args:
        inputs: Tensor of shape (batch_size, seq_len, input_dim).
        mask: Unused here; accepted for interface compatibility.

    Returns:
        Tensor of shape (batch_size, seq_len, input_dim) after dropout.

    Raises:
        ValueError: If the sequence exceeds the embedding table size.

    """
    batch_size, length = inputs.size(0), inputs.size(1)
    if length > self._max_len:
        raise ValueError(f"Sequence length {length} exceeds maximum length {self._max_len}")

    # The same position ids are used for every batch element.
    indices = torch.arange(length, dtype=torch.long, device=inputs.device)
    indices = indices.unsqueeze(0).expand(batch_size, -1)

    return self.dropout(inputs + self.position_embeddings(indices))

get_input_dim

get_input_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
1062
1063
def get_input_dim(self) -> int:
    """Return the expected input feature dimension."""
    return self._input_dim

get_output_dim

get_output_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
1065
1066
def get_output_dim(self) -> int:
    """Return the output feature dimension (same as the input dimension)."""
    return self._input_dim

TransformerEncoder

TransformerEncoder(
    input_dim,
    num_heads,
    num_layers,
    feedforward_dim,
    dropout=0.1,
    positional_encoder=None,
    attention_mask=None,
    activation="relu",
    layer_norm_eps=1e-05,
    batch_first=True,
)

Bases: BaseSequenceEncoder

Transformer-based sequence encoder.

Uses stacked TransformerEncoderLayers with positional encoding and configurable attention masking via dependency injection.

PARAMETER DESCRIPTION
input_dim

Dimension of the embeddings (d_model).

TYPE: int

num_heads

Number of attention heads.

TYPE: int

num_layers

Number of transformer layers.

TYPE: int

feedforward_dim

Dimension of feedforward network.

TYPE: int

dropout

Dropout rate.

TYPE: float DEFAULT: 0.1

positional_encoder

Optional positional encoder to add position information.

TYPE: Optional[BasePositionalEncoder] DEFAULT: None

attention_mask

Optional mask generator for self-attention.

TYPE: Optional[BaseAttentionMask] DEFAULT: None

activation

Activation function (default: "relu").

TYPE: str DEFAULT: 'relu'

layer_norm_eps

Epsilon for layer normalization.

TYPE: float DEFAULT: 1e-05

batch_first

Whether input is batch-first (default: True).

TYPE: bool DEFAULT: True

Examples:

>>> from formed.integrations.torch.modules.encoders import (
...     TransformerEncoder,
...     SinusoidalPositionalEncoder,
...     CausalMask
... )
>>>
>>> # Standard transformer encoder
>>> encoder = TransformerEncoder(
...     input_dim=512,
...     num_heads=8,
...     num_layers=6,
...     feedforward_dim=2048,
...     dropout=0.1,
...     positional_encoder=SinusoidalPositionalEncoder(input_dim=512)
... )
>>>
>>> # Transformer with causal masking (for autoregressive tasks)
>>> causal_encoder = TransformerEncoder(
...     input_dim=512,
...     num_heads=8,
...     num_layers=6,
...     feedforward_dim=2048,
...     dropout=0.1,
...     positional_encoder=SinusoidalPositionalEncoder(input_dim=512),
...     attention_mask=CausalMask()
... )
Source code in src/formed/integrations/torch/modules/encoders.py
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
def __init__(
    self,
    input_dim: int,
    num_heads: int,
    num_layers: int,
    feedforward_dim: int,
    dropout: float = 0.1,
    positional_encoder: Optional[BasePositionalEncoder] = None,
    attention_mask: Optional[BaseAttentionMask] = None,
    activation: str = "relu",
    layer_norm_eps: float = 1e-5,
    batch_first: bool = True,
) -> None:
    """Assemble a stack of post-norm transformer encoder layers.

    Args:
        input_dim: Model dimension (``d_model``).
        num_heads: Number of attention heads per layer.
        num_layers: Number of stacked layers.
        feedforward_dim: Hidden dimension of each layer's feedforward block.
        dropout: Dropout rate used throughout the layers.
        positional_encoder: Optional module that injects position information
            before the transformer stack runs.
        attention_mask: Optional generator for structural attention masks
            (e.g. causal or sliding-window).
        activation: Name of the feedforward activation function.
        layer_norm_eps: Epsilon used by the layer norms.
        batch_first: If True, inputs are (batch, seq, dim); otherwise
            (seq, batch, dim).

    """
    super().__init__()
    self._input_dim = input_dim
    self._positional_encoder = positional_encoder
    self._attention_mask = attention_mask
    self._batch_first = batch_first

    # Template layer; nn.TransformerEncoder replicates it num_layers times.
    layer = nn.TransformerEncoderLayer(
        d_model=input_dim,
        nhead=num_heads,
        dim_feedforward=feedforward_dim,
        dropout=dropout,
        activation=activation,
        layer_norm_eps=layer_norm_eps,
        batch_first=batch_first,
        norm_first=False,
    )
    self.transformer_encoder = nn.TransformerEncoder(
        encoder_layer=layer,
        num_layers=num_layers,
    )

transformer_encoder instance-attribute

transformer_encoder = TransformerEncoder(
    encoder_layer=encoder_layer, num_layers=num_layers
)

forward

forward(inputs, mask=None)

Encode input sequence using transformer.

PARAMETER DESCRIPTION
inputs

Input of shape (batch_size, seq_len, input_dim) if batch_first=True, or (seq_len, batch_size, input_dim) if batch_first=False.

TYPE: Tensor

mask

Optional mask of shape (batch_size, seq_len) where 1=valid, 0=padding.

TYPE: Optional[Tensor] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Encoded sequence of same shape as input.

Source code in src/formed/integrations/torch/modules/encoders.py
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
def forward(
    self,
    inputs: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the stacked transformer layers over ``inputs``.

    Args:
        inputs: (batch_size, seq_len, input_dim) when batch_first=True,
            otherwise (seq_len, batch_size, input_dim).
        mask: Optional padding mask of shape (batch_size, seq_len) where
            1=valid and 0=padding.

    Returns:
        Encoded sequence with the same shape as ``inputs``.

    """
    # Inject position information before the attention layers, if configured.
    # NOTE(review): positional encoders here index dim 1 as the sequence
    # axis — confirm compatibility when batch_first=False.
    if self._positional_encoder is not None:
        inputs = self._positional_encoder(inputs, mask=mask)

    if self._batch_first:
        batch_size, seq_len = inputs.size(0), inputs.size(1)
    else:
        batch_size, seq_len = inputs.size(1), inputs.size(0)

    # Structural attention mask (causal, sliding-window, ...), if configured.
    # Generators return (seq_len, seq_len) or (batch_size, seq_len, seq_len).
    src_mask = None
    if self._attention_mask is not None:
        src_mask = self._attention_mask(
            seq_len=seq_len,
            batch_size=batch_size,
            device=inputs.device,
            padding_mask=mask,
        )
        if src_mask is not None:
            src_mask = src_mask.to(inputs.device)

    # PyTorch's TransformerEncoder expects True for padded (ignored)
    # positions — the inverse of the 1=valid convention used by `mask`.
    src_key_padding_mask = None if mask is None else ~mask.bool()

    return self.transformer_encoder(
        inputs,
        mask=src_mask,
        src_key_padding_mask=src_key_padding_mask,
    )

get_input_dim

get_input_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
1207
1208
def get_input_dim(self) -> int:
    """Return the expected input feature dimension (d_model)."""
    return self._input_dim

get_output_dim

get_output_dim()
Source code in src/formed/integrations/torch/modules/encoders.py
1210
1211
def get_output_dim(self) -> int:
    """Return the output feature dimension (same as the input dimension)."""
    return self._input_dim

formed.integrations.torch.modules.feedforward

Feed-forward neural network modules for PyTorch models.

This module provides feed-forward network layers with support for multiple layers, dropout, layer normalization, and residual connections.

Key Components
  • FeedForward: Multi-layer feed-forward network
Features
  • Configurable activation functions
  • Layer normalization
  • Dropout for regularization
  • Residual connections

Examples:

>>> from formed.integrations.torch.modules import FeedForward
>>> import torch.nn as nn
>>>
>>> # Simple 3-layer feed-forward network
>>> ffn = FeedForward(
...     input_dim=256,
...     hidden_dims=[512, 512, 256],
...     dropout=0.1,
...     activation=nn.GELU()
... )

FeedForward

FeedForward(
    input_dim, hidden_dims, dropout=0.0, activation=ReLU()
)

Bases: Module

A simple feed forward neural network.

PARAMETER DESCRIPTION
input_dim

The dimension of the input.

TYPE: int

hidden_dims

A sequence of integers specifying the dimensions of each layer.

TYPE: Sequence[int]

dropout

The dropout probability. Defaults to 0.0.

TYPE: float DEFAULT: 0.0

activation

The activation function. Defaults to torch.nn.ReLU()

TYPE: Module DEFAULT: ReLU()

Source code in src/formed/integrations/torch/modules/feedforward.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def __init__(
    self,
    input_dim: int,
    hidden_dims: Sequence[int],
    dropout: float = 0.0,
    activation: torch.nn.Module = torch.nn.ReLU(),
) -> None:
    """Build a stack of Linear -> Dropout -> activation layers.

    Args:
        input_dim: Dimension of the input features.
        hidden_dims: Output dimension of each successive layer.
        dropout: Dropout probability applied after every linear layer.
        activation: Activation module applied after dropout.
            NOTE(review): the default instance is created once at import
            time and the same module object is reused in every layer;
            harmless for stateless activations like ReLU, but worth
            confirming for stateful ones.

    """
    super().__init__()
    self._input_dim = input_dim
    self._hidden_dims = hidden_dims
    self._dropout = dropout
    self._activation = activation

    dims = [input_dim, *hidden_dims]
    self._layers = torch.nn.ModuleList(
        torch.nn.Sequential(
            torch.nn.Linear(in_dim, out_dim),
            torch.nn.Dropout(dropout),
            activation,
        )
        for in_dim, out_dim in zip(dims[:-1], dims[1:])
    )

forward

forward(inputs)
PARAMETER DESCRIPTION
inputs

A tensor of shape (batch_size, ..., input_dim).

TYPE: FloatTensor

RETURNS DESCRIPTION
FloatTensor

A tensor of shape (batch_size, ..., hidden_dims[-1]).

Source code in src/formed/integrations/torch/modules/feedforward.py
71
72
73
74
75
76
77
78
79
80
81
82
def forward(self, inputs: torch.FloatTensor) -> torch.FloatTensor:
    """Pass ``inputs`` through every layer in order.

    Args:
        inputs: A tensor of shape `(batch_size, ..., input_dim)`.

    Returns:
        A tensor of shape `(batch_size, ..., hidden_dims[-1])`.
    """
    hidden = inputs
    for block in self._layers:
        hidden = block(hidden)
    return hidden

get_input_dim

get_input_dim()
Source code in src/formed/integrations/torch/modules/feedforward.py
84
85
def get_input_dim(self) -> int:
    """Return the expected input feature dimension."""
    return self._input_dim

get_output_dim

get_output_dim()
Source code in src/formed/integrations/torch/modules/feedforward.py
87
88
def get_output_dim(self) -> int:
    """Return the output dimension: the width of the last hidden layer."""
    return self._hidden_dims[-1]

formed.integrations.torch.modules.losses

Loss functions for classification tasks.

This module provides loss functions for classification with support for label weighting and different reduction strategies.

Key Components
  • BaseClassificationLoss: Abstract base class for classification losses
  • CrossEntropyLoss: Standard cross-entropy loss with optional weighting

Examples:

>>> from formed.integrations.torch.modules import CrossEntropyLoss
>>> import torch
>>>
>>> # Simple cross-entropy
>>> loss_fn = CrossEntropyLoss()
>>> logits = torch.randn(4, 10)  # (batch_size, num_classes)
>>> labels = torch.randint(0, 10, (4,))  # (batch_size,)
>>> loss = loss_fn(logits, labels)
>>>
>>> # With label weighting
>>> from formed.integrations.torch.modules import StaticLabelWeighter
>>> weighter = StaticLabelWeighter(weights=torch.ones(10))
>>> loss_fn = CrossEntropyLoss(weighter=weighter)

BaseClassificationLoss

Bases: Module, Registrable, Generic[_ParamsT], ABC

Abstract base class for classification loss functions.

A ClassificationLoss defines a strategy for computing loss based on model logits and true labels.

CLASS TYPE PARAMETER DESCRIPTION
_ParamsT

Type of additional parameters used during loss computation.

forward abstractmethod

forward(logits, labels, params=None)

Compute the classification loss.

PARAMETER DESCRIPTION
logits

Model output logits of shape (..., num_classes).

TYPE: Tensor

labels

True target labels of shape (...).

TYPE: Tensor

params

Optional additional parameters for loss computation.

TYPE: Optional[_ParamsT] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Computed loss as a scalar tensor.

Source code in src/formed/integrations/torch/modules/losses.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
@abc.abstractmethod
def forward(self, logits: torch.Tensor, labels: torch.Tensor, params: Optional[_ParamsT] = None) -> torch.Tensor:
    """Compute the classification loss; subclasses must implement this.

    Args:
        logits: Model output logits of shape `(..., num_classes)`.
        labels: Ground-truth labels of shape `(...)`.
        params: Optional extra inputs consumed by concrete implementations.

    Returns:
        The computed loss as a scalar tensor.

    """
    raise NotImplementedError

CrossEntropyLoss

CrossEntropyLoss(weighter=None, reduce='mean')

Bases: BaseClassificationLoss[_ParamsT]

Cross-entropy loss for classification tasks.

PARAMETER DESCRIPTION
weighter

An optional label weighter to assign weights to each class.

TYPE: Optional[BaseLabelWeighter[_ParamsT]] DEFAULT: None

reduce

Reduction method - "mean" or "sum".

TYPE: Literal['mean', 'sum'] DEFAULT: 'mean'

Examples:

>>> loss_fn = CrossEntropyLoss()
>>> logits = torch.randn(4, 10)
>>> labels = torch.randint(0, 10, (4,))
>>> loss = loss_fn(logits, labels)
Source code in src/formed/integrations/torch/modules/losses.py
89
90
91
92
93
94
95
96
def __init__(
    self,
    weighter: Optional[BaseLabelWeighter[_ParamsT]] = None,
    reduce: Literal["mean", "sum"] = "mean",
) -> None:
    """Store the optional class weighter and the reduction mode.

    Args:
        weighter: Optional per-class weight provider.
        reduce: Either "mean" or "sum"; validated lazily in `forward`.

    """
    super().__init__()
    self._weighter = weighter
    self._reduce = reduce

forward

forward(logits, labels, params=None)

Compute cross-entropy loss.

PARAMETER DESCRIPTION
logits

Logits of shape (..., num_classes).

TYPE: Tensor

labels

Labels of shape (...).

TYPE: TensorCompatible

params

Optional parameters for the weighter.

TYPE: Optional[_ParamsT] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Loss scalar.

Source code in src/formed/integrations/torch/modules/losses.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def forward(
    self,
    logits: torch.Tensor,
    labels: TensorCompatible,
    params: Optional[_ParamsT] = None,
) -> torch.Tensor:
    """Compute (optionally class-weighted) cross-entropy loss.

    Args:
        logits: Logits of shape `(..., num_classes)`.
        labels: Integer class labels of shape `(...)`.
        params: Forwarded to the weighter, if one is configured.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If the configured reduction is unknown.

    """
    labels = ensure_torch_tensor(labels)

    num_classes = logits.shape[-1]
    targets = F.one_hot(labels.long(), num_classes=num_classes).float()
    log_probs = F.log_softmax(logits, dim=-1)

    # Select each example's log-probability via the one-hot target,
    # scaling per class first when a weighter is configured.
    per_class = targets * log_probs
    if self._weighter is not None:
        per_class = per_class * self._weighter(logits, labels, params)
    per_example = -per_class.sum(dim=-1)

    if self._reduce == "mean":
        return per_example.mean()
    if self._reduce == "sum":
        return per_example.sum()
    raise ValueError(f"Unknown reduce operation: {self._reduce}")

BCEWithLogitsLoss

BCEWithLogitsLoss(
    weighter=None, reduce="mean", pos_weight=None
)

Bases: BaseClassificationLoss[_ParamsT]

Binary cross-entropy loss with logits for multilabel classification tasks.

This loss combines a sigmoid layer and the BCE loss in a single class. This formulation is more numerically stable than applying a plain sigmoid followed by a separate BCE loss.

PARAMETER DESCRIPTION
weighter

An optional label weighter to assign weights to each class.

TYPE: Optional[BaseLabelWeighter[_ParamsT]] DEFAULT: None

reduce

Reduction method - "mean" or "sum".

TYPE: Literal['mean', 'sum'] DEFAULT: 'mean'

pos_weight

Optional weight for positive examples per class.

TYPE: Optional[TensorCompatible] DEFAULT: None

Examples:

>>> loss_fn = BCEWithLogitsLoss()
>>> logits = torch.randn(4, 10)  # (batch_size, num_classes)
>>> labels = torch.randint(0, 2, (4, 10)).float()  # (batch_size, num_classes)
>>> loss = loss_fn(logits, labels)
Source code in src/formed/integrations/torch/modules/losses.py
155
156
157
158
159
160
161
162
163
164
def __init__(
    self,
    weighter: Optional[BaseLabelWeighter[_ParamsT]] = None,
    reduce: Literal["mean", "sum"] = "mean",
    pos_weight: Optional[TensorCompatible] = None,
) -> None:
    """Store weighting and reduction configuration.

    Args:
        weighter: Optional per-element weight provider.
        reduce: Either "mean" or "sum"; validated lazily in `forward`.
        pos_weight: Optional per-class weight for positive examples,
            converted to a tensor up front.

    """
    super().__init__()
    self._weighter = weighter
    self._reduce = reduce
    self._pos_weight = None if pos_weight is None else ensure_torch_tensor(pos_weight)

forward

forward(logits, labels, params=None)

Compute BCE with logits loss.

PARAMETER DESCRIPTION
logits

Logits of shape (..., num_classes).

TYPE: Tensor

labels

Binary labels of shape (..., num_classes).

TYPE: TensorCompatible

params

Optional parameters for the weighter.

TYPE: Optional[_ParamsT] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Loss scalar.

Source code in src/formed/integrations/torch/modules/losses.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
def forward(
    self,
    logits: torch.Tensor,
    labels: TensorCompatible,
    params: Optional[_ParamsT] = None,
) -> torch.Tensor:
    """Compute (optionally weighted) binary cross-entropy from logits.

    Args:
        logits: Logits of shape `(..., num_classes)`.
        labels: Binary labels of shape `(..., num_classes)`.
        params: Forwarded to the weighter, if one is configured.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If the configured reduction is unknown.

    """
    labels = ensure_torch_tensor(labels).float()

    # Unreduced elementwise loss so optional weights can be applied first.
    per_element = F.binary_cross_entropy_with_logits(
        logits, labels, pos_weight=self._pos_weight, reduction="none"
    )
    if self._weighter is not None:
        per_element = per_element * self._weighter(logits, labels, params)

    if self._reduce == "mean":
        return per_element.mean()
    if self._reduce == "sum":
        return per_element.sum()
    raise ValueError(f"Unknown reduce operation: {self._reduce}")

BaseRegressionLoss

Bases: Module, Registrable, Generic[_ParamsT], ABC

Abstract base class for regression loss functions.

A RegressionLoss defines a strategy for computing loss based on model predictions and true labels.

CLASS TYPE PARAMETER DESCRIPTION
_ParamsT

Type of additional parameters used during loss computation.

forward abstractmethod

forward(predictions, labels, params=None)

Compute the regression loss.

PARAMETER DESCRIPTION
predictions

Model output predictions of shape (...).

TYPE: Tensor

labels

True target labels of shape (...).

TYPE: Tensor

params

Optional additional parameters for loss computation.

TYPE: Optional[_ParamsT] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Computed loss as a scalar tensor.

Source code in src/formed/integrations/torch/modules/losses.py
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
@abc.abstractmethod
def forward(
    self, predictions: torch.Tensor, labels: torch.Tensor, params: Optional[_ParamsT] = None
) -> torch.Tensor:
    """Compute the regression loss; subclasses must implement this.

    Args:
        predictions: Model output predictions of shape `(...)`.
        labels: Ground-truth targets of shape `(...)`.
        params: Optional extra inputs consumed by concrete implementations.

    Returns:
        The computed loss as a scalar tensor.

    """
    raise NotImplementedError

MeanSquaredErrorLoss

MeanSquaredErrorLoss(reduce='mean')

Bases: BaseRegressionLoss[_ParamsT]

Mean Squared Error (MSE) loss for regression tasks.

PARAMETER DESCRIPTION
reduce

Reduction method - "mean" or "sum".

TYPE: Literal['mean', 'sum'] DEFAULT: 'mean'

Examples:

>>> loss_fn = MeanSquaredErrorLoss()
>>> predictions = torch.randn(4)
>>> labels = torch.randn(4)
>>> loss = loss_fn(predictions, labels)
Source code in src/formed/integrations/torch/modules/losses.py
247
248
249
250
251
252
def __init__(
    self,
    reduce: Literal["mean", "sum"] = "mean",
) -> None:
    """Store the reduction mode ("mean" or "sum"); validated lazily in `forward`."""
    super().__init__()
    self._reduce = reduce

forward

forward(predictions, labels, params=None)

Compute MSE loss.

PARAMETER DESCRIPTION
predictions

Predictions of shape (...).

TYPE: Tensor

labels

Labels of shape (...).

TYPE: TensorCompatible

params

Ignored.

TYPE: Optional[_ParamsT] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Loss scalar.

Source code in src/formed/integrations/torch/modules/losses.py
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
def forward(
    self,
    predictions: torch.Tensor,
    labels: TensorCompatible,
    params: Optional[_ParamsT] = None,
) -> torch.Tensor:
    """Compute the mean squared error between predictions and labels.

    Args:
        predictions: Predictions of shape `(...)`.
        labels: Targets of shape `(...)`; converted to a tensor first.
        params: Ignored.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If the configured reduction is unknown.

    """
    labels = ensure_torch_tensor(labels)

    squared_error = (predictions - labels) ** 2

    if self._reduce == "mean":
        return squared_error.mean()
    if self._reduce == "sum":
        return squared_error.sum()
    raise ValueError(f"Unknown reduce operation: {self._reduce}")

formed.integrations.torch.modules.masks

Attention mask generation for transformer models. This module provides reusable attention mask generators for transformer-based models in PyTorch. Attention masks control which positions in a sequence can attend to which other positions, enabling various attention patterns such as causal masking for autoregressive models or sliding window attention for long sequences.

Key Components
  • BaseAttentionMask: Abstract base class for attention mask generators
  • CausalMask: Generates causal (autoregressive) attention masks
  • SlidingWindowAttentionMask: Generates sliding window attention masks
  • CombinedMask: Combines multiple attention masks into a single mask
Features
  • Standardized attention mask format compatible with PyTorch Transformer modules
  • Support for batch-wise and sequence-wise masks
  • Easily extensible via registration system for custom masks

Examples:

>>> from formed.integrations.torch.modules import CausalMask
>>>
>>> # Create a causal mask generator
>>> mask_generator = CausalMask()
>>>
>>> # Generate a causal mask for sequence length 5 and batch size 2
>>> mask = mask_generator(seq_len=5, batch_size=2, device=torch.device('cpu'))
>>> # mask shape will be (5, 5) with float values: 0.0 for attendable positions,
>>> # float('-inf') for masked positions

BaseAttentionMask

Bases: Registrable, ABC

Base class for attention mask generation.

Attention masks control which positions can attend to which other positions in transformer models.

All attention masks must return a mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len) using float values where:

  • 0.0 indicates positions that CAN be attended to
  • float('-inf') indicates positions that should NOT be attended to

This standardized format ensures compatibility with PyTorch's TransformerEncoder.mask parameter.

CausalMask

Bases: BaseAttentionMask

Generates causal (autoregressive) attention masks.

Causal masks ensure that each position can only attend to itself and previous positions, enabling autoregressive generation.

Examples:

>>> masks = CausalMask()
>>> mask = masks(seq_len=4, batch_size=1, device=torch.device('cpu'))
>>> # mask[i, j] = 0.0 if j <= i else float('-inf')

SlidingWindowAttentionMask

SlidingWindowAttentionMask(window_size)

Bases: BaseAttentionMask

Sliding window attention mask.

Restricts attention to a local window around each position, enabling efficient processing of long sequences. Commonly used in models like Longformer and Mistral.

PARAMETER DESCRIPTION
window_size

Size of the attention window on each side. Total window is (2 * window_size + 1) centered on each position.

TYPE: int

Examples:

>>> # Window size of 1 means each position can attend to itself and
>>> # one position on each side
>>> mask_gen = SlidingWindowAttentionMask(window_size=1)
>>> mask = mask_gen(seq_len=4, batch_size=1, device=torch.device('cpu'))
>>> # Position 0: can attend to [0, 1]
>>> # Position 1: can attend to [0, 1, 2]
>>> # Position 2: can attend to [1, 2, 3]
>>> # Position 3: can attend to [2, 3]
Source code in src/formed/integrations/torch/modules/masks.py
148
149
150
151
def __init__(self, window_size: int) -> None:
    """Store the validated attention-window radius.

    Args:
        window_size: Non-negative window radius on each side of a position.

    Raises:
        ValueError: If ``window_size`` is negative.

    """
    if window_size < 0:
        raise ValueError(f"window_size must be non-negative, got {window_size}")
    self.window_size = window_size

window_size instance-attribute

window_size = window_size

CombinedMask

CombinedMask(masks)

Bases: BaseAttentionMask

Combines multiple attention masks.

Applies multiple masks in sequence and combines their results. A position is masked if ANY mask blocks it (logical OR for -inf values).

PARAMETER DESCRIPTION
masks

List of attention masks to combine.

TYPE: Sequence[BaseAttentionMask]

Examples:

>>> # Combine multiple structural masks
>>> mask1 = CausalMask()
>>> mask2 = SomeOtherMask()
>>> combined = CombinedMask(masks=[mask1, mask2])
Source code in src/formed/integrations/torch/modules/masks.py
202
203
def __init__(self, masks: Sequence[BaseAttentionMask]) -> None:
    """Keep the sequence of attention masks to be applied together."""
    self.masks = masks

masks instance-attribute

masks = masks

formed.integrations.torch.modules.samplers

Label samplers for classification tasks.

This module provides samplers that convert model logits into discrete labels.

Key Components
  • BaseLabelSampler: Abstract base class for label samplers
  • ArgmaxLabelSampler: Selects the label with highest logit
  • MultinomialLabelSampler: Samples from categorical distribution
  • BaseMultilabelSampler: Abstract base class for multilabel samplers
  • ThresholdMultilabelSampler: Selects labels above a threshold
  • TopKMultilabelSampler: Selects top-k labels
  • BernoulliMultilabelSampler: Samples labels from independent Bernoulli distributions

Examples:

>>> from formed.integrations.torch.modules import ArgmaxLabelSampler, MultinomialLabelSampler
>>> import torch
>>>
>>> logits = torch.randn(4, 10)  # (batch_size, num_classes)
>>>
>>> # Argmax sampling (deterministic)
>>> argmax_sampler = ArgmaxLabelSampler()
>>> labels = argmax_sampler(logits)
>>>
>>> # Multinomial sampling (stochastic)
>>> multi_sampler = MultinomialLabelSampler()
>>> labels = multi_sampler(logits, temperature=0.8)

BaseLabelSampler

Bases: Module, Registrable, Generic[_ParamsT], ABC

Abstract base class for label samplers.

A LabelSampler defines a strategy for sampling labels based on model logits.

CLASS TYPE PARAMETER DESCRIPTION
_ParamsT

Type of additional parameters used during sampling.

forward abstractmethod

forward(logits, params=None)

Sample labels from logits.

PARAMETER DESCRIPTION
logits

Model output logits of shape (..., num_classes).

TYPE: Tensor

params

Additional parameters for sampling.

TYPE: Optional[_ParamsT] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Sampled labels of shape (...).

Source code in src/formed/integrations/torch/modules/samplers.py
52
53
54
55
56
57
58
59
60
61
62
63
64
@abc.abstractmethod
def forward(self, logits: torch.Tensor, params: Optional[_ParamsT] = None) -> torch.Tensor:
    """Draw label predictions from raw logits.

    Concrete samplers implement this to map model scores onto discrete labels.

    Args:
        logits: Model output logits of shape `(..., num_classes)`.
        params: Optional sampler-specific parameters.

    Returns:
        Sampled labels of shape `(...)`.

    """
    raise NotImplementedError

ArgmaxLabelSampler

Bases: BaseLabelSampler[None]

Label sampler that selects the label with the highest logit.

Examples:

>>> sampler = ArgmaxLabelSampler()
>>> logits = torch.randn(4, 10)
>>> labels = sampler(logits)  # Shape: (4,)

forward

forward(logits, params=None)

Select the argmax label.

PARAMETER DESCRIPTION
logits

Logits of shape (..., num_classes).

TYPE: Tensor

params

Ignored.

TYPE: None DEFAULT: None

RETURNS DESCRIPTION
Tensor

Labels of shape (...).

Source code in src/formed/integrations/torch/modules/samplers.py
81
82
83
84
85
86
87
88
89
90
91
92
def forward(self, logits: torch.Tensor, params: None = None) -> torch.Tensor:
    """Return the index of the highest-scoring class.

    Args:
        logits: Logits of shape `(..., num_classes)`.
        params: Unused; present for interface compatibility.

    Returns:
        Labels of shape `(...)`.

    """
    return torch.argmax(logits, dim=-1)

MultinomialLabelSamplerParams

Bases: TypedDict

Parameters for MultinomialLabelSampler.

ATTRIBUTE DESCRIPTION
temperature

Sampling temperature to control randomness. Higher temperature = more random, lower = more deterministic.

TYPE: float

temperature instance-attribute

temperature

MultinomialLabelSampler

Bases: BaseLabelSampler[MultinomialLabelSamplerParams]

Label sampler that samples labels from a multinomial distribution.

Examples:

>>> sampler = MultinomialLabelSampler()
>>> logits = torch.randn(4, 10)
>>>
>>> # Sample with default temperature
>>> labels = sampler(logits)
>>>
>>> # Sample with temperature scaling
>>> labels = sampler(logits, temperature=0.5)

forward

forward(logits, params=None)

Sample labels from categorical distribution.

PARAMETER DESCRIPTION
logits

Logits of shape (..., num_classes).

TYPE: Tensor

params

Optional parameters containing temperature for sampling.

TYPE: Optional[MultinomialLabelSamplerParams] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Sampled labels of shape (...).

Source code in src/formed/integrations/torch/modules/samplers.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
def forward(self, logits: torch.Tensor, params: Optional[MultinomialLabelSamplerParams] = None) -> torch.Tensor:
    """Sample labels from the categorical distribution defined by `logits`.

    Args:
        logits: Logits of shape `(..., num_classes)`.
        params: Optional parameters; may contain a `temperature` key
            (higher values flatten the distribution, lower values sharpen it).

    Returns:
        Sampled labels of shape `(...)`.

    """
    # Consistent with the other samplers in this module: an absent or empty
    # params dict means "use defaults".
    temperature = (params or {}).get("temperature", 1.0)
    # NOTE(review): temperature == 0.0 would divide by zero here; callers are
    # expected to pass a strictly positive temperature.
    if temperature != 1.0:
        logits = logits / temperature

    probs = F.softmax(logits, dim=-1)
    # `torch.multinomial` requires a 2D input, so flatten the leading
    # dimensions, draw one sample per row, then restore the original shape.
    return torch.multinomial(probs.view(-1, probs.shape[-1]), num_samples=1).view(probs.shape[:-1])

BaseMultilabelSampler

Bases: Module, Registrable, Generic[_ParamsT], ABC

Abstract base class for multilabel samplers.

A MultilabelSampler defines a strategy for sampling multiple labels based on model logits.

CLASS TYPE PARAMETER DESCRIPTION
_ParamsT

Type of additional parameters used during sampling.

forward abstractmethod

forward(logits, params=None)

Sample multiple labels from logits.

PARAMETER DESCRIPTION
logits

Model output logits of shape (..., num_classes).

TYPE: Tensor

params

Additional parameters for sampling.

TYPE: Optional[_ParamsT] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Sampled labels of shape (..., num_labels).

Source code in src/formed/integrations/torch/modules/samplers.py
153
154
155
156
157
158
159
160
161
162
163
164
165
@abc.abstractmethod
def forward(self, logits: torch.Tensor, params: Optional[_ParamsT] = None) -> torch.Tensor:
    """Draw a set of labels from raw logits.

    Concrete samplers implement this to pick any number of labels per example.

    Args:
        logits: Model output logits of shape `(..., num_classes)`.
        params: Optional sampler-specific parameters.

    Returns:
        Sampled labels of shape `(..., num_labels)`.

    """
    raise NotImplementedError

ThresholdMultilabelSamplerParams

Bases: TypedDict

Parameters for ThresholdMultilabelSampler.

ATTRIBUTE DESCRIPTION
threshold

Probability threshold for selecting labels.

TYPE: float

threshold instance-attribute

threshold

ThresholdMultilabelSampler

ThresholdMultilabelSampler(threshold=0.5)

Bases: BaseMultilabelSampler[ThresholdMultilabelSamplerParams]

Multilabel sampler that selects labels above a certain threshold.

Examples:

>>> sampler = ThresholdMultilabelSampler(threshold=0.5)
>>> logits = torch.randn(4, 10)
>>> labels = sampler(logits)  # Shape: (4, num_labels)
Source code in src/formed/integrations/torch/modules/samplers.py
193
194
195
def __init__(self, threshold: float = 0.5) -> None:
    super().__init__()
    # Default decision boundary; can be overridden per call via `params`.
    self.threshold = threshold

threshold instance-attribute

threshold = threshold

forward

forward(logits, params=None)

Select labels above the threshold.

PARAMETER DESCRIPTION
logits

Logits of shape (..., num_classes).

TYPE: Tensor

params

Optional parameters containing threshold.

TYPE: Optional[ThresholdMultilabelSamplerParams] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Labels of shape (..., num_labels).

Source code in src/formed/integrations/torch/modules/samplers.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
def forward(
    self,
    logits: torch.Tensor,
    params: Optional[ThresholdMultilabelSamplerParams] = None,
) -> torch.Tensor:
    """Mark every label whose probability reaches the threshold.

    Args:
        logits: Logits of shape `(..., num_classes)`.
        params: Optional parameters; may override the configured threshold.

    Returns:
        Binary indicator labels of shape `(..., num_labels)`.

    """
    if params and "threshold" in params:
        cutoff = params["threshold"]
    else:
        cutoff = self.threshold
    return torch.sigmoid(logits).ge(cutoff).float()

TopKMultilabelSamplerParams

Bases: TypedDict

Parameters for TopKMultilabelSampler.

ATTRIBUTE DESCRIPTION
k

Number of top labels to select.

TYPE: int

k instance-attribute

k

TopKMultilabelSampler

TopKMultilabelSampler(k=1)

Bases: BaseMultilabelSampler[TopKMultilabelSamplerParams]

Multilabel sampler that selects the top-k labels.

Examples:

>>> sampler = TopKMultilabelSampler(k=3)
>>> logits = torch.randn(4, 10)
>>> labels = sampler(logits)  # Shape: (4, num_labels)
Source code in src/formed/integrations/torch/modules/samplers.py
239
240
241
def __init__(self, k: int = 1) -> None:
    super().__init__()
    # Default number of labels to select; can be overridden per call via `params`.
    self.k = k

k instance-attribute

k = k

forward

forward(logits, params=None)

Select the top-k labels.

PARAMETER DESCRIPTION
logits

Logits of shape (..., num_classes).

TYPE: Tensor

params

Optional parameters containing k for top-k selection.

TYPE: Optional[TopKMultilabelSamplerParams] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Labels of shape (..., num_labels).

Source code in src/formed/integrations/torch/modules/samplers.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
def forward(
    self,
    logits: torch.Tensor,
    params: Optional[TopKMultilabelSamplerParams] = None,
) -> torch.Tensor:
    """Mark the k highest-scoring labels.

    Args:
        logits: Logits of shape `(..., num_classes)`.
        params: Optional parameters; may override the configured `k`.

    Returns:
        Binary indicator labels of shape `(..., num_labels)`.

    """
    num_selected = (params or {}).get("k", self.k)
    selected = torch.topk(logits, num_selected, dim=-1).indices
    indicator = torch.zeros_like(logits)
    indicator.scatter_(-1, selected, 1.0)
    return indicator

BernoulliMultilabelSampler

Bases: BaseMultilabelSampler[None]

Multilabel sampler that samples labels from independent Bernoulli distributions.

Examples:

>>> sampler = BernoulliMultilabelSampler()
>>> logits = torch.randn(4, 10)
>>> labels = sampler(logits)  # Shape: (4, num_labels)

forward

forward(logits, params=None)

Sample labels from Bernoulli distributions.

PARAMETER DESCRIPTION
logits

Logits of shape (..., num_classes).

TYPE: Tensor

params

Ignored.

TYPE: None DEFAULT: None

RETURNS DESCRIPTION
Tensor

Sampled labels of shape (..., num_labels).

Source code in src/formed/integrations/torch/modules/samplers.py
275
276
277
278
279
280
281
282
283
284
285
286
287
def forward(self, logits: torch.Tensor, params: None = None) -> torch.Tensor:
    """Draw each label independently from its own Bernoulli distribution.

    Args:
        logits: Logits of shape `(..., num_classes)`.
        params: Unused; present for interface compatibility.

    Returns:
        Sampled binary labels of shape `(..., num_labels)`.

    """
    success_probs = torch.sigmoid(logits)
    return torch.bernoulli(success_probs)

formed.integrations.torch.modules.scalarmix

ScalarMix

ScalarMix(
    mixture_size,
    do_layer_norm=False,
    initial_scalar_parameters=None,
    trainable=True,
)

Bases: Module

Computes a parameterised scalar mixture of N tensors, mixture = gamma * sum(s_k * tensor_k) where s = softmax(w), with w and gamma scalar parameters. In addition, if do_layer_norm=True then apply layer normalization to each tensor before weighting.

Note

This script is based on the AllenNLP implementation of ScalarMix: https://github.com/allenai/allennlp/blob/v2.10.0/allennlp/modules/scalar_mix.py

Source code in src/formed/integrations/torch/modules/scalarmix.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def __init__(
    self,
    mixture_size: int,
    do_layer_norm: bool = False,
    initial_scalar_parameters: Sequence[float] | None = None,
    trainable: bool = True,
) -> None:
    """Initialize the scalar mixture.

    Args:
        mixture_size: Number of tensors to mix.
        do_layer_norm: Whether to layer-normalize each tensor before weighting.
        initial_scalar_parameters: Initial values for the mixture weights.
            Defaults to all zeros (a uniform mixture after softmax).
        trainable: Whether the weights and gamma are learnable.

    Raises:
        ValueError: If `initial_scalar_parameters` does not have
            `mixture_size` elements.

    """
    super().__init__()
    self.mixture_size = mixture_size
    self.do_layer_norm = do_layer_norm

    if initial_scalar_parameters is None:
        initial_scalar_parameters = [0.0] * mixture_size
    elif len(initial_scalar_parameters) != mixture_size:
        # Report the actual length — the previous message announced a
        # length but formatted the full parameter list instead.
        raise ValueError(
            "Length of initial_scalar_parameters ({}) differs from mixture_size ({})".format(
                len(initial_scalar_parameters), mixture_size
            )
        )

    # One scalar weight per input tensor, kept as separate parameters so
    # each can be inspected or frozen individually.
    self.scalar_parameters = torch.nn.ParameterList(
        [
            torch.nn.Parameter(torch.tensor([value], dtype=torch.float), requires_grad=trainable)
            for value in initial_scalar_parameters
        ]
    )
    # Global scale applied after the softmax-weighted sum.
    self.gamma = torch.nn.Parameter(torch.tensor([1.0], dtype=torch.float), requires_grad=trainable)

mixture_size instance-attribute

mixture_size = mixture_size

do_layer_norm instance-attribute

do_layer_norm = do_layer_norm

scalar_parameters instance-attribute

scalar_parameters = ParameterList(
    [
        (
            Parameter(
                FloatTensor(
                    [initial_scalar_parameters[i]]
                ),
                requires_grad=trainable,
            )
        )
        for i in (range(mixture_size))
    ]
)

gamma instance-attribute

gamma = Parameter(
    FloatTensor([1.0]), requires_grad=trainable
)

forward

forward(tensors, mask=None)

Compute a weighted average of the tensors. The input tensors can be any shape with at least two dimensions, but must all be the same shape. When do_layer_norm=True, the mask is a required input. If the tensors are dimensioned (dim_0, ..., dim_{n-1}, dim_n), then the mask is dimensioned (dim_0, ..., dim_{n-1}), as in the typical case with tensors of shape (batch_size, timesteps, dim) and mask of shape (batch_size, timesteps). When do_layer_norm=False the mask is ignored.

Source code in src/formed/integrations/torch/modules/scalarmix.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def forward(self, tensors: Sequence[torch.Tensor], mask: torch.Tensor | None = None) -> torch.Tensor:
    """
    Compute a weighted average of the `tensors`.  The input tensors can be any shape
    with at least two dimensions, but must all be the same shape.
    When `do_layer_norm=True`, the `mask` is a required input.  If the `tensors` are
    dimensioned  `(dim_0, ..., dim_{n-1}, dim_n)`, then the `mask` is dimensioned
    `(dim_0, ..., dim_{n-1})`, as in the typical case with `tensors` of shape
    `(batch_size, timesteps, dim)` and `mask` of shape `(batch_size, timesteps)`.
    When `do_layer_norm=False` the `mask` is ignored.
    """
    if len(tensors) != self.mixture_size:
        raise ValueError(
            "{} tensors were passed, but the module was initialized to mix {} tensors.".format(
                len(tensors), self.mixture_size
            )
        )

    # Masked layer norm. Note: mean/variance are computed over ALL unmasked
    # elements across the whole batch (a single scalar statistic), not per
    # example — this matches the AllenNLP ScalarMix behavior.
    def _do_layer_norm(
        tensor: torch.Tensor,
        broadcast_mask: torch.Tensor,
        num_elements_not_masked: torch.Tensor,
    ) -> torch.Tensor:
        tensor_masked = tensor * broadcast_mask
        mean = torch.sum(tensor_masked) / num_elements_not_masked
        variance = torch.sum(((tensor_masked - mean) * broadcast_mask) ** 2) / num_elements_not_masked
        # Small epsilon guards against division by zero for constant inputs.
        return (tensor - mean) / torch.sqrt(variance + 1e-13)

    # Softmax over the scalar weights, then split back into one
    # single-element weight tensor per input tensor.
    normed_weights = torch.split(
        torch.nn.functional.softmax(torch.cat([parameter for parameter in self.scalar_parameters]), dim=0),
        split_size_or_sections=1,
    )

    pieces: list[torch.Tensor]
    if not self.do_layer_norm:
        # Plain weighted sum, scaled by the learned gamma.
        pieces = [weight * tensor for weight, tensor in zip(normed_weights, tensors)]
        return self.gamma * sum(pieces)

    else:
        # Layer-norm path: a mask is mandatory (see docstring).
        assert mask is not None
        broadcast_mask = mask.unsqueeze(-1)
        input_dim = tensors[0].size(-1)
        # Total count of unmasked scalar elements, used as the normalizer.
        num_elements_not_masked = torch.sum(mask) * input_dim

        pieces = [
            weight * _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked)
            for weight, tensor in zip(normed_weights, tensors)
        ]
        return self.gamma * sum(pieces)

formed.integrations.torch.modules.vectorizers

Sequence vectorization modules for PyTorch models.

This module provides vectorizers that convert variable-length sequences into fixed-size vectors. Vectorizers apply pooling operations over the sequence dimension to produce single vectors per sequence.

Key Components
  • BaseSequenceVectorizer: Abstract base class for vectorizers
  • BagOfEmbeddingsSequenceVectorizer: Pools sequence embeddings
Features
  • Multiple pooling strategies (mean, max, min, sum, first, last, hier)
  • Masked pooling to ignore padding tokens
  • Optional normalization before pooling
  • Hierarchical pooling with sliding windows

Examples:

>>> from formed.integrations.torch.modules import BagOfEmbeddingsSequenceVectorizer
>>>
>>> # Mean pooling over sequence
>>> vectorizer = BagOfEmbeddingsSequenceVectorizer(pooling="mean")
>>> vector = vectorizer(embeddings, mask=mask)
>>>
>>> # Max pooling with normalization
>>> vectorizer = BagOfEmbeddingsSequenceVectorizer(
...     pooling="max",
...     normalize=True
... )

BaseSequenceVectorizer

Bases: Module, Registrable, ABC

Abstract base class for sequence vectorizers.

Vectorizers convert variable-length sequences into fixed-size vectors by applying pooling operations over the sequence dimension.

forward abstractmethod

forward(inputs, *, mask=None)

Vectorize a sequence into a fixed-size vector.

PARAMETER DESCRIPTION
inputs

Input embeddings of shape (batch_size, seq_len, embedding_dim).

TYPE: Tensor

mask

Optional attention mask of shape (batch_size, seq_len).

TYPE: Optional[Tensor] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Vectorized output of shape (batch_size, output_dim).

Source code in src/formed/integrations/torch/modules/vectorizers.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
@abc.abstractmethod
def forward(
    self,
    inputs: torch.Tensor,
    *,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Collapse a sequence of embeddings into one fixed-size vector.

    Args:
        inputs: Input embeddings of shape `(batch_size, seq_len, embedding_dim)`.
        mask: Optional attention mask of shape `(batch_size, seq_len)`.

    Returns:
        Vectorized output of shape `(batch_size, output_dim)`.

    """
    raise NotImplementedError

get_input_dim abstractmethod

get_input_dim()

Get the expected input dimension.

RETURNS DESCRIPTION
Optional[int]

Input dimension or None if dimension-agnostic.

Source code in src/formed/integrations/torch/modules/vectorizers.py
70
71
72
73
74
75
76
77
78
@abc.abstractmethod
def get_input_dim(self) -> Optional[int]:
    """Report the input dimension this vectorizer expects.

    Returns:
        The required input dimension, or None when any dimension is accepted.

    """
    raise NotImplementedError

get_output_dim abstractmethod

get_output_dim()

Get the output dimension.

RETURNS DESCRIPTION
Union[int, Callable[[int], int]]

Output feature dimension or a function mapping input dim to output dim.

Source code in src/formed/integrations/torch/modules/vectorizers.py
80
81
82
83
84
85
86
87
88
@abc.abstractmethod
def get_output_dim(self) -> Union[int, Callable[[int], int]]:
    """Report the output dimension.

    Returns:
        Either a fixed output dimension, or a function that maps the input
        dimension to the output dimension.

    """
    raise NotImplementedError

BagOfEmbeddingsSequenceVectorizer

BagOfEmbeddingsSequenceVectorizer(
    pooling="mean", normalize=False, window_size=None
)

Bases: BaseSequenceVectorizer

Bag-of-embeddings vectorizer using pooling operations.

This vectorizer applies pooling over the sequence dimension to create fixed-size vectors. Multiple pooling strategies are supported, and padding tokens are properly masked during pooling.

PARAMETER DESCRIPTION
pooling

Pooling strategy to use: - "mean": Average pooling (default) - "max": Max pooling - "min": Min pooling - "sum": Sum pooling - "first": Take first token - "last": Take last non-padding token - "hier": Hierarchical pooling with sliding window

TYPE: Union[PoolingMethod, Sequence[PoolingMethod]] DEFAULT: 'mean'

normalize

Whether to L2-normalize embeddings before pooling.

TYPE: bool DEFAULT: False

window_size

Window size for hierarchical pooling (required if pooling="hier").

TYPE: Optional[int] DEFAULT: None

Examples:

>>> # Mean pooling
>>> vectorizer = BagOfEmbeddingsSequenceVectorizer(pooling="mean")
>>> vector = vectorizer(embeddings, mask=mask)
>>>
>>> # Max pooling with normalization
>>> vectorizer = BagOfEmbeddingsSequenceVectorizer(
...     pooling="max",
...     normalize=True
... )
>>>
>>> # Multiple pooling methods combined
>>> vectorizer = BagOfEmbeddingsSequenceVectorizer(
...     pooling=["mean", "max"]
... )
>>>
>>> # Hierarchical pooling
>>> vectorizer = BagOfEmbeddingsSequenceVectorizer(
...     pooling="hier",
...     window_size=3
... )
Note

This vectorizer is dimension-agnostic - it preserves the embedding dimension from input to output (multiplied by number of pooling methods).

Source code in src/formed/integrations/torch/modules/vectorizers.py
148
149
150
151
152
153
154
155
156
157
def __init__(
    self,
    pooling: Union[PoolingMethod, Sequence[PoolingMethod]] = "mean",
    normalize: bool = False,
    window_size: Optional[int] = None,
) -> None:
    """Configure the pooling behavior.

    Args:
        pooling: Pooling strategy, or a list of strategies whose outputs
            are concatenated — e.g. "mean", "max", "min", "sum", "first",
            "last", or "hier".
        normalize: Whether to L2-normalize embeddings before pooling.
        window_size: Sliding-window size for hierarchical ("hier") pooling.

    """
    super().__init__()
    self._pooling: Union[PoolingMethod, Sequence[PoolingMethod]] = pooling
    self._normalize = normalize
    self._window_size = window_size

forward

forward(inputs, *, mask=None)

Vectorize sequence using bag-of-embeddings pooling.

PARAMETER DESCRIPTION
inputs

Input embeddings of shape (batch_size, seq_len, input_dim).

TYPE: Tensor

mask

Optional attention mask of shape (batch_size, seq_len). True indicates valid positions, False indicates padding.

TYPE: Optional[Tensor] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Vectorized output of shape (batch_size, output_dim).

Tensor

If multiple pooling methods are used, output_dim = input_dim * num_pooling.

Source code in src/formed/integrations/torch/modules/vectorizers.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
def forward(
    self,
    inputs: torch.Tensor,
    *,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Pool token embeddings into one vector per sequence.

    Args:
        inputs: Input embeddings of shape `(batch_size, seq_len, input_dim)`.
        mask: Optional attention mask of shape `(batch_size, seq_len)`;
            True marks valid positions, False marks padding.

    Returns:
        Pooled output of shape `(batch_size, output_dim)`. With multiple
        pooling methods, output_dim = input_dim * num_pooling.

    """
    # All pooling logic lives in `masked_pool`; this module only forwards
    # its stored configuration.
    pooled = masked_pool(
        inputs,
        mask=mask,
        pooling=self._pooling,
        normalize=self._normalize,
        window_size=self._window_size,
    )
    return pooled

get_input_dim

get_input_dim()

Get the expected input dimension.

RETURNS DESCRIPTION
None

None (dimension-agnostic vectorizer).

Source code in src/formed/integrations/torch/modules/vectorizers.py
185
186
187
188
189
190
191
192
def get_input_dim(self) -> None:
    """Report the expected input dimension.

    Returns:
        None — this vectorizer accepts any embedding dimension.

    """
    return None

get_output_dim

get_output_dim()

Get the output dimension.

RETURNS DESCRIPTION
Callable[[int], int]

Function mapping input dimension to output dimension.

Callable[[int], int]

Output dimension = input_dim * number_of_pooling_methods.

Source code in src/formed/integrations/torch/modules/vectorizers.py
194
195
196
197
198
199
200
201
202
203
def get_output_dim(self) -> Callable[[int], int]:
    """Report the output dimension.

    Returns:
        A function mapping the input dimension to the output dimension,
        which is input_dim multiplied by the number of pooling methods.

    """
    if isinstance(self._pooling, str):
        num_pooling = 1
    else:
        num_pooling = len(self._pooling)

    def _output_dim(input_dim: int) -> int:
        return input_dim * num_pooling

    return _output_dim

CnnSequenceVectorizer

CnnSequenceVectorizer(
    input_dim,
    num_filters,
    ngram_filter_sizes=(2, 3, 4, 5),
    conv_layer_activation=None,
    output_dim=None,
)

Bases: BaseSequenceVectorizer

CNN-based sequence vectorizer using multiple n-gram filters.

This vectorizer applies multiple 1D convolutions with different kernel sizes (n-gram filters) to capture local patterns of different lengths. Max pooling is applied over each filter's output to create a fixed-size representation.

Based on "A Sensitivity Analysis of (and Practitioners' Guide to) Convolutional Neural Networks for Sentence Classification" by Zhang and Wallace (2016).

PARAMETER DESCRIPTION
input_dim

Input embedding dimension.

TYPE: int

num_filters

Number of filters for each n-gram size.

TYPE: int

ngram_filter_sizes

Tuple of n-gram sizes for convolution filters. Default is (2, 3, 4, 5) for bigrams through 5-grams.

TYPE: Sequence[int] DEFAULT: (2, 3, 4, 5)

conv_layer_activation

Activation function after convolution. Default is ReLU.

TYPE: Optional[Callable[[Tensor], Tensor]] DEFAULT: None

output_dim

Optional output dimension. If provided, applies linear projection after concatenating filter outputs.

TYPE: Optional[int] DEFAULT: None

Examples:

>>> from formed.integrations.torch.modules.vectorizers import CnnSequenceVectorizer
>>>
>>> # Standard CNN with multiple n-gram filters
>>> vectorizer = CnnSequenceVectorizer(
...     input_dim=128,
...     num_filters=100,
...     ngram_filter_sizes=(2, 3, 4, 5)
... )
>>> # Output dim = 100 * 4 = 400
>>>
>>> # With output projection
>>> vectorizer = CnnSequenceVectorizer(
...     input_dim=128,
...     num_filters=100,
...     ngram_filter_sizes=(3, 4, 5),
...     output_dim=256
... )
>>>
>>> # Custom activation
>>> import torch.nn as nn
>>> vectorizer = CnnSequenceVectorizer(
...     input_dim=128,
...     num_filters=50,
...     ngram_filter_sizes=(2, 3),
...     conv_layer_activation=nn.Tanh()
... )
Note
  • Properly handles padding masks to avoid max-pooling over padding positions
  • Output dimension without projection: num_filters * len(ngram_filter_sizes)
  • Each filter extracts patterns of a specific n-gram size
Source code in src/formed/integrations/torch/modules/vectorizers.py
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
def __init__(
    self,
    input_dim: int,
    num_filters: int,
    ngram_filter_sizes: Sequence[int] = (2, 3, 4, 5),
    conv_layer_activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    output_dim: Optional[int] = None,
) -> None:
    """Build one Conv1d per n-gram size plus an optional output projection.

    Args:
        input_dim: Input embedding dimension (Conv1d in_channels).
        num_filters: Number of filters per n-gram size (Conv1d out_channels).
        ngram_filter_sizes: Kernel sizes; one convolution per entry.
        conv_layer_activation: Activation applied after each convolution;
            defaults to ReLU.
        output_dim: If given, a linear projection maps the concatenated
            max-pooled filter outputs to this dimension.

    """
    super().__init__()
    self._input_dim = input_dim
    self._num_filters = num_filters
    self._ngram_filter_sizes = ngram_filter_sizes
    self._activation = conv_layer_activation or torch.nn.ReLU()

    # Plain Python list, so the layers are registered with the module
    # explicitly below via `add_module`; `forward` retrieves them by the
    # same "conv_layer_%d" names.
    self._convolution_layers = [
        torch.nn.Conv1d(
            in_channels=self._input_dim,
            out_channels=self._num_filters,
            kernel_size=ngram_size,
        )
        for ngram_size in self._ngram_filter_sizes
    ]
    for i, conv_layer in enumerate(self._convolution_layers):
        self.add_module("conv_layer_%d" % i, conv_layer)

    # Each filter is max-pooled to a single value, so the pre-projection
    # size is num_filters per n-gram size.
    maxpool_output_dim = self._num_filters * len(self._ngram_filter_sizes)
    self.projection_layer: Optional[torch.nn.Linear]
    self._output_dim: int
    # NOTE(review): truthiness check — `output_dim=0` is treated like None.
    if output_dim:
        self.projection_layer = torch.nn.Linear(maxpool_output_dim, output_dim)
        self._output_dim = output_dim
    else:
        self.projection_layer = None
        self._output_dim = maxpool_output_dim

projection_layer instance-attribute

projection_layer

get_input_dim

get_input_dim()

Get the expected input dimension.

RETURNS DESCRIPTION
int

Input embedding dimension.

Source code in src/formed/integrations/torch/modules/vectorizers.py
296
297
298
299
300
301
302
303
def get_input_dim(self) -> int:
    """Report the embedding dimension the convolutions were built for.

    Returns:
        Input embedding dimension.

    """
    return self._input_dim

get_output_dim

get_output_dim()

Get the output dimension.

RETURNS DESCRIPTION
int

Output vector dimension (num_filters * len(ngram_filter_sizes) or custom output_dim).

Source code in src/formed/integrations/torch/modules/vectorizers.py
305
306
307
308
309
310
311
312
def get_output_dim(self) -> int:
    """Report the output vector dimension.

    Returns:
        The projection size when a projection layer was configured,
        otherwise num_filters * len(ngram_filter_sizes).

    """
    return self._output_dim

forward

forward(inputs, *, mask=None)

Vectorize sequence using CNN with multiple n-gram filters.

PARAMETER DESCRIPTION
inputs

Input embeddings of shape (batch_size, seq_len, input_dim).

TYPE: Tensor

mask

Optional attention mask of shape (batch_size, seq_len). True indicates valid positions, False indicates padding.

TYPE: Optional[Tensor] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Vectorized output of shape (batch_size, output_dim).

Source code in src/formed/integrations/torch/modules/vectorizers.py
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
def forward(
    self,
    inputs: torch.Tensor,
    *,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Vectorize sequence using CNN with multiple n-gram filters.

    Args:
        inputs: Input embeddings of shape `(batch_size, seq_len, input_dim)`.
        mask: Optional attention mask of shape `(batch_size, seq_len)`.
             True indicates valid positions, False indicates padding.

    Returns:
        Vectorized output of shape `(batch_size, output_dim)`.

    """
    # Zero out padded embeddings; without a mask, treat all positions valid.
    if mask is not None:
        inputs = inputs * mask.unsqueeze(-1)
    else:
        mask = torch.ones(*inputs.size()[:-1], device=inputs.device).bool()

    # Conv1d expects (batch_size, channels, seq_len).
    inputs = torch.transpose(inputs, 1, 2)

    filter_outputs = []
    batch_size = inputs.shape[0]
    last_unmasked_inputs = mask.sum(dim=1).unsqueeze(dim=-1)  # Shape: (batch_size, 1)
    for i in range(len(self._convolution_layers)):
        # Layers were registered under these names in __init__ via add_module.
        convolution_layer = getattr(self, "conv_layer_{}".format(i))
        # Number of valid convolution windows for this kernel size.
        pool_length = inputs.shape[2] - convolution_layer.kernel_size[0] + 1

        activations = self._activation(convolution_layer(inputs))

        # Build a mask over windows that overlap the padded tail, so the
        # max-pool below cannot select them.
        indices = (
            torch.arange(pool_length, device=activations.device).unsqueeze(0).expand(batch_size, pool_length)
        )  # Shape: (batch_size, pool_length)
        activations_mask = indices.ge(
            last_unmasked_inputs - convolution_layer.kernel_size[0] + 1
        )  # Shape: (batch_size, pool_length)
        activations_mask = activations_mask.unsqueeze(1).expand_as(
            activations
        )  # Shape: (batch_size, num_filters, pool_length)

        # Push masked windows to the dtype minimum (additive masking keeps
        # shapes stable and the op differentiable).
        activations = activations + (
            activations_mask * min_value_of_dtype(activations.dtype)
        )  # Shape: (batch_size, pool_length)

        # Pick out the max filters
        filter_outputs.append(activations.max(dim=2)[0])

    # Concatenate the pooled outputs of all kernel sizes.
    maxpool_output = torch.cat(filter_outputs, dim=1) if len(filter_outputs) > 1 else filter_outputs[0]
    # Entries still at the dtype minimum came from fully-masked rows; reset
    # them to zero.
    maxpool_output[maxpool_output == min_value_of_dtype(maxpool_output.dtype)] = 0.0

    if self.projection_layer:
        result = self.projection_layer(maxpool_output)
    else:
        result = maxpool_output
    return result

SelfAttentiveSequenceVectorizer

SelfAttentiveSequenceVectorizer(
    input_dim, num_heads=1, hidden_dims=()
)

Bases: BaseSequenceVectorizer

Self-attentive sequence vectorizer using learned attention weights.

This vectorizer uses learned attention mechanisms to compute weighted averages of sequence embeddings. Multiple attention heads can be used to capture different aspects of the sequence.

Based on "A Structured Self-attentive Sentence Embedding" by Lin et al. (2017).

PARAMETER DESCRIPTION
input_dim

Input embedding dimension. Must be divisible by num_heads.

TYPE: int

num_heads

Number of attention heads. Each head learns different attention patterns.

TYPE: int DEFAULT: 1

hidden_dims

Hidden dimensions for the attention scoring network. Empty tuple means direct scoring without hidden layers.

TYPE: Sequence[int] DEFAULT: ()

Examples:

>>> from formed.integrations.torch.modules.vectorizers import (
...     SelfAttentiveSequenceVectorizer
... )
>>>
>>> # Single attention head
>>> vectorizer = SelfAttentiveSequenceVectorizer(
...     input_dim=128,
...     num_heads=1
... )
>>>
>>> # Multiple attention heads
>>> vectorizer = SelfAttentiveSequenceVectorizer(
...     input_dim=128,
...     num_heads=4
... )
>>>
>>> # With hidden layers in attention scorer
>>> vectorizer = SelfAttentiveSequenceVectorizer(
...     input_dim=128,
...     num_heads=2,
...     hidden_dims=(64,)
... )
Note
  • Each attention head operates on input_dim // num_heads dimensions
  • Outputs are concatenated across heads to preserve input dimension
  • Properly handles padding masks via masked softmax
Source code in src/formed/integrations/torch/modules/vectorizers.py
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
def __init__(
    self,
    input_dim: int,
    num_heads: int = 1,
    hidden_dims: Sequence[int] = (),
) -> None:
    """Initialize the self-attentive vectorizer.

    Args:
        input_dim: Input embedding dimension; must be divisible by `num_heads`.
        num_heads: Number of attention heads.
        hidden_dims: Hidden layer sizes of the attention scoring network;
            empty means a single linear scoring layer.

    Raises:
        ValueError: If `input_dim` is not divisible by `num_heads`.

    """
    # Validate with an explicit exception rather than `assert`, which is
    # stripped when Python runs with -O.
    if input_dim % num_heads != 0:
        raise ValueError("Input dimension must be divisible by number of heads.")

    super().__init__()
    self._input_dim = input_dim
    self._num_heads = num_heads
    self._head_dim = input_dim // num_heads

    # One scorer per head: a small MLP followed by a linear scoring layer,
    # or a single linear layer when no hidden dims are given.
    self._scorers = torch.nn.ModuleList(
        [
            torch.nn.Sequential(
                FeedForward(
                    input_dim=self._head_dim,
                    hidden_dims=hidden_dims,
                ),
                torch.nn.Linear(hidden_dims[-1], 1),
            )
            if hidden_dims
            else torch.nn.Linear(self._head_dim, 1)
            for _ in range(num_heads)
        ]
    )

get_input_dim

get_input_dim()

Get the expected input dimension.

RETURNS DESCRIPTION
int

Input embedding dimension.

Source code in src/formed/integrations/torch/modules/vectorizers.py
449
450
451
452
453
454
455
456
def get_input_dim(self) -> int:
    """Return the embedding dimension this vectorizer expects as input."""
    return self._input_dim

get_output_dim

get_output_dim()

Get the output dimension.

RETURNS DESCRIPTION
int

Output dimension (same as input dimension).

Source code in src/formed/integrations/torch/modules/vectorizers.py
458
459
460
461
462
463
464
465
def get_output_dim(self) -> int:
    """Return the output dimension, which equals the input dimension."""
    return self._input_dim

forward

forward(inputs, *, mask=None)

Vectorize sequence using self-attention.

PARAMETER DESCRIPTION
inputs

Input embeddings of shape (batch_size, seq_len, input_dim).

TYPE: Tensor

mask

Optional attention mask of shape (batch_size, seq_len). True indicates valid positions, False indicates padding.

TYPE: Tensor | None DEFAULT: None

RETURNS DESCRIPTION
Tensor

Vectorized output of shape (batch_size, input_dim).

Source code in src/formed/integrations/torch/modules/vectorizers.py
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
def forward(
    self,
    inputs: torch.Tensor,
    *,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Pool a sequence into a single vector via per-head additive attention.

    Args:
        inputs: Input embeddings of shape `(batch_size, seq_len, input_dim)`.
        mask: Optional attention mask of shape `(batch_size, seq_len)`.
             True indicates valid positions, False indicates padding.

    Returns:
        Vectorized output of shape `(batch_size, input_dim)`.

    """
    if mask is None:
        # No mask provided: treat every position as valid.
        mask = inputs.new_ones(inputs.size()[:-1], dtype=torch.bool)

    if mask.dim() == 2:
        # Add a trailing axis so the mask broadcasts over features:
        # (batch_size, seq_len, 1)
        mask = mask.unsqueeze(-1)

    # Zero out padded positions before scoring.
    inputs = inputs * mask.float()

    pooled_heads = []
    for index, scorer in enumerate(self._scorers):
        start = index * self._head_dim
        # This head's feature slice: (batch_size, seq_len, head_dim)
        features = inputs[..., start : start + self._head_dim]
        # Attention distribution over the sequence: (batch_size, seq_len, 1)
        weights = masked_softmax(scorer(features), mask, dim=1)
        # Attention-weighted sum over time: (batch_size, head_dim)
        pooled_heads.append((weights * features).sum(dim=1))

    # Concatenate heads back to the full dimension: (batch_size, input_dim)
    return torch.cat(pooled_heads, dim=-1)

ConcatSequenceVectorizer

ConcatSequenceVectorizer(vectorizers)

Bases: BaseSequenceVectorizer

Concatenates outputs from multiple sequence vectorizers.

Applies multiple vectorizers to the same input sequence and concatenates their outputs along the feature dimension. This allows combining different vectorization strategies (e.g., mean pooling + max pooling + attention).

PARAMETER DESCRIPTION
vectorizers

List of vectorizers to apply in parallel. All vectorizers receive the same input sequence.

TYPE: Sequence[BaseSequenceVectorizer]

Examples:

>>> from formed.integrations.torch.modules.vectorizers import (
...     ConcatSequenceVectorizer,
...     BagOfEmbeddingsSequenceVectorizer,
...     SelfAttentiveSequenceVectorizer
... )
>>>
>>> # Combine mean pooling and max pooling
>>> vectorizers = [
...     BagOfEmbeddingsSequenceVectorizer(pooling="mean"),
...     BagOfEmbeddingsSequenceVectorizer(pooling="max"),
... ]
>>> vectorizer = ConcatSequenceVectorizer(vectorizers=vectorizers)
>>>
>>> # Combine pooling and attention
>>> vectorizers = [
...     BagOfEmbeddingsSequenceVectorizer(pooling="mean"),
...     SelfAttentiveSequenceVectorizer(input_dim=128, num_heads=2),
... ]
>>> vectorizer = ConcatSequenceVectorizer(vectorizers=vectorizers)
Note
  • Output dimension is the sum of all vectorizer output dimensions
  • Handles both fixed and dynamic output dimensions from vectorizers
Source code in src/formed/integrations/torch/modules/vectorizers.py
550
551
552
def __init__(self, vectorizers: Sequence[BaseSequenceVectorizer]) -> None:
    """Wrap the given vectorizers in a ModuleList so their parameters are registered."""
    super().__init__()
    self._vectorizers = torch.nn.ModuleList(vectorizers)

get_input_dim

get_input_dim()

Get the expected input dimension.

RETURNS DESCRIPTION
int | None

First non-None input dimension from vectorizers, or None if all are None.

Source code in src/formed/integrations/torch/modules/vectorizers.py
554
555
556
557
558
559
560
561
562
def get_input_dim(self) -> int | None:
    """Return the first concrete input dimension reported by the wrapped vectorizers.

    Returns:
        First non-None input dimension from vectorizers, or None if all are None.

    """
    vectorizers = cast(Sequence[BaseSequenceVectorizer], self._vectorizers)
    # Query every vectorizer (matching the original call pattern), then
    # pick the first concrete dimension.
    dims = [vectorizer.get_input_dim() for vectorizer in vectorizers]
    for dim in dims:
        if dim is not None:
            return dim
    return None

get_output_dim

get_output_dim()

Get the output dimension.

RETURNS DESCRIPTION
int | Callable[[int], int]

Sum of all vectorizer output dimensions if all are fixed integers,

int | Callable[[int], int]

otherwise a function that computes the sum given an input dimension.

Source code in src/formed/integrations/torch/modules/vectorizers.py
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
def get_output_dim(self) -> int | Callable[[int], int]:
    """Return the combined output dimension across all wrapped vectorizers.

    Returns:
        Sum of all vectorizer output dimensions if every one is a fixed
        integer, otherwise a function mapping an input dimension to that sum.

    """
    vectorizers = cast(Sequence[BaseSequenceVectorizer], self._vectorizers)
    # These are *output* dimensions (or callables producing them).
    output_dims = [vectorizer.get_output_dim() for vectorizer in vectorizers]
    if all(isinstance(dim, int) for dim in output_dims):
        return sum(cast(int, dim) for dim in output_dims)

    def _resolve(input_dim: int) -> int:
        # Fixed dims contribute directly; callable dims are evaluated
        # against the caller-supplied input dimension.
        return sum(dim if isinstance(dim, int) else dim(input_dim) for dim in output_dims)

    return _resolve

forward

forward(inputs, *, mask=None)

Vectorize sequence by concatenating multiple vectorizer outputs.

PARAMETER DESCRIPTION
inputs

Input embeddings of shape (batch_size, seq_len, input_dim).

TYPE: Tensor

mask

Optional attention mask of shape (batch_size, seq_len).

TYPE: Tensor | None DEFAULT: None

RETURNS DESCRIPTION
Tensor

Concatenated vectors of shape (batch_size, output_dim).

Source code in src/formed/integrations/torch/modules/vectorizers.py
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
def forward(
    self,
    inputs: torch.Tensor,
    *,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Vectorize sequence by concatenating multiple vectorizer outputs.

    Args:
        inputs: Input embeddings of shape `(batch_size, seq_len, input_dim)`.
        mask: Optional attention mask of shape `(batch_size, seq_len)`.

    Returns:
        Concatenated vectors of shape `(batch_size, output_dim)`.

    """
    vectors = [vectorizer(inputs, mask=mask) for vectorizer in self._vectorizers]
    return torch.cat(vectors, dim=-1)

formed.integrations.torch.modules.weighters

Label weighters for classification tasks.

This module provides weighters that assign weights to class labels, useful for handling imbalanced datasets.

Key Components
  • BaseLabelWeighter: Abstract base class for label weighters
  • StaticLabelWeighter: Uses fixed weights per class
  • BalancedByDistributionLabelWeighter: Balances based on class distribution

Examples:

>>> from formed.integrations.torch.modules import StaticLabelWeighter
>>> import torch
>>>
>>> # Static weights for 3 classes
>>> weights = torch.tensor([1.0, 2.0, 3.0])  # Weight rare classes more
>>> weighter = StaticLabelWeighter(weights=weights)
>>>
>>> logits = torch.randn(4, 3)
>>> labels = torch.tensor([0, 1, 2, 0])
>>> class_weights = weighter(logits, labels)  # Shape: (1, 3)

BaseLabelWeighter

Bases: Module, Registrable, Generic[_ParamsT], ABC

Abstract base class for label weighters.

A LabelWeighter defines a strategy for assigning weights to each label based on model logits and true targets.

CLASS TYPE PARAMETER DESCRIPTION
_ParamsT

Type of additional parameters used during weighting.

forward abstractmethod

forward(logits, targets, params=None)

Compute weights for each target label.

PARAMETER DESCRIPTION
logits

Model output logits of shape (..., num_classes).

TYPE: Tensor

targets

True target labels of shape (...).

TYPE: Tensor

params

Optional additional parameters for weighting.

TYPE: Optional[_ParamsT] DEFAULT: None

RETURNS DESCRIPTION
Tensor

Weights for each logit of shape (1, num_classes) or broadcastable to logits shape.

Source code in src/formed/integrations/torch/modules/weighters.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
@abc.abstractmethod
def forward(
    self,
    logits: torch.Tensor,
    targets: torch.Tensor,
    params: Optional[_ParamsT] = None,
) -> torch.Tensor:
    """Compute weights for each target label.

    Args:
        logits: Model output logits of shape `(..., num_classes)`.
        targets: True target labels of shape `(...)`.
        params: Optional additional parameters for weighting.

    Returns:
        Weights for each logit of shape `(1, num_classes)` or broadcastable to logits shape.

    """
    # Abstract: subclasses must implement their own weighting strategy.
    raise NotImplementedError

StaticLabelWeighter

StaticLabelWeighter(weights)

Bases: BaseLabelWeighter[None]

Label weighter that assigns static weights to each class.

PARAMETER DESCRIPTION
weights

A tensor of shape (num_classes,) containing the weight for each class.

TYPE: TensorCompatible

Examples:

>>> # Weight class 1 twice as much as class 0
>>> weights = torch.tensor([1.0, 2.0, 1.0])
>>> weighter = StaticLabelWeighter(weights=weights)
>>> logits = torch.randn(4, 3)
>>> labels = torch.tensor([0, 1, 2, 0])
>>> class_weights = weighter(logits, labels)
Source code in src/formed/integrations/torch/modules/weighters.py
87
88
89
90
def __init__(self, weights: TensorCompatible) -> None:
    """Store the fixed per-class weights as a non-trainable buffer."""
    super().__init__()
    self.register_buffer("_weights", ensure_torch_tensor(weights, dtype=torch.float))
    # Annotation only: tells type checkers the buffer registered above is a Tensor.
    self._weights: torch.Tensor

forward

forward(logits, targets, params=None)

Return static weights.

PARAMETER DESCRIPTION
logits

Ignored.

TYPE: Tensor

targets

Ignored.

TYPE: Tensor

params

Ignored.

TYPE: None DEFAULT: None

RETURNS DESCRIPTION
Tensor

Weights of shape (1, num_classes).

Source code in src/formed/integrations/torch/modules/weighters.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
def forward(
    self,
    logits: torch.Tensor,
    targets: torch.Tensor,
    params: None = None,
) -> torch.Tensor:
    """Return the fixed per-class weights, ignoring all arguments.

    Args:
        logits: Ignored.
        targets: Ignored.
        params: Ignored.

    Returns:
        Weights of shape `(1, num_classes)`.

    """
    # Prepend a batch axis so the weights broadcast against logits.
    return self._weights[None, :]

BalancedByDistributionLabelWeighter

BalancedByDistributionLabelWeighter(
    distribution, eps=1e-08
)

Bases: BaseLabelWeighter[None]

Label weighter that balances classes based on their distribution.

The weight for each class is computed as: 1 / (distribution * num_classes + eps)

PARAMETER DESCRIPTION
distribution

A tensor of shape (num_classes,) representing the class distribution (should sum to 1.0).

TYPE: TensorCompatible

eps

A small epsilon value to avoid division by zero.

TYPE: float DEFAULT: 1e-08

Examples:

>>> # Class distribution: 50%, 30%, 20%
>>> distribution = torch.tensor([0.5, 0.3, 0.2])
>>> weighter = BalancedByDistributionLabelWeighter(distribution=distribution)
>>> logits = torch.randn(4, 3)
>>> labels = torch.tensor([0, 1, 2, 0])
>>> class_weights = weighter(logits, labels)
>>> # Rare classes get higher weights
Source code in src/formed/integrations/torch/modules/weighters.py
134
135
136
137
138
def __init__(self, distribution: TensorCompatible, eps: float = 1e-8) -> None:
    """Store the class distribution as a buffer and keep the epsilon for safe division."""
    super().__init__()
    self.register_buffer("_distribution", ensure_torch_tensor(distribution, dtype=torch.float))
    # Annotation only: tells type checkers the buffer registered above is a Tensor.
    self._distribution: torch.Tensor
    self._eps = eps

forward

forward(logits, targets, params=None)

Compute balanced weights.

PARAMETER DESCRIPTION
logits

Ignored.

TYPE: Tensor

targets

Ignored.

TYPE: Tensor

params

Ignored.

TYPE: None DEFAULT: None

RETURNS DESCRIPTION
Tensor

Weights of shape (1, num_classes).

Source code in src/formed/integrations/torch/modules/weighters.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
def forward(
    self,
    logits: torch.Tensor,
    targets: torch.Tensor,
    params: None = None,
) -> torch.Tensor:
    """Compute inverse-frequency class weights, ignoring all arguments.

    The weight for each class is ``1 / (distribution * num_classes + eps)``.

    Args:
        logits: Ignored.
        targets: Ignored.
        params: Ignored.

    Returns:
        Weights of shape `(1, num_classes)`.

    """
    num_classes = self._distribution.size(0)
    # Epsilon guards against division by zero for classes with zero mass.
    scaled = self._distribution * num_classes + self._eps
    return (1.0 / scaled).unsqueeze(0)

formed.integrations.torch.training.callbacks

Training callbacks for monitoring and controlling PyTorch model training.

This module provides a callback system for PyTorch training, allowing custom logic to be executed at various points in the training loop. Callbacks can monitor metrics, save checkpoints, implement early stopping, and integrate with experiment tracking systems.

Key Components
  • TorchTrainingCallback: Base class for all callbacks
  • EvaluationCallback: Computes metrics using custom evaluators
  • EarlyStoppingCallback: Stops training based on metric improvements
  • MlflowCallback: Logs metrics to MLflow
Features
  • Hook points at training/epoch/batch start and end
  • Metric computation and logging
  • Model checkpointing
  • Early stopping with patience
  • MLflow integration
  • Extensible for custom callbacks

Examples:

>>> from formed.integrations.torch import (
...     TorchTrainer,
...     EarlyStoppingCallback,
...     EvaluationCallback,
...     MlflowCallback
... )
>>>
>>> trainer = TorchTrainer(
...     train_dataloader=train_loader,
...     val_dataloader=val_loader,
...     callbacks=[
...         EvaluationCallback(my_evaluator),
...         EarlyStoppingCallback(patience=5, metric="-loss"),
...         MlflowCallback()
...     ]
... )

TorchTrainingCallback

Bases: Registrable

Base class for training callbacks.

Callbacks provide hooks to execute custom logic at various points during training. Subclasses can override any hook method to implement custom behavior such as logging, checkpointing, or early stopping.

Hook execution order
  1. on_training_start - once at the beginning
  2. on_epoch_start - at the start of each epoch
  3. on_batch_start - before each training batch
  4. on_batch_end - after each training batch
  5. on_eval_start - before evaluation (returns evaluator)
  6. on_eval_end - after evaluation with computed metrics
  7. on_log - when metrics are logged
  8. on_epoch_end - at the end of each epoch
  9. on_training_end - once at the end (can modify final state)

Examples:

>>> @TorchTrainingCallback.register("my_callback")
... class MyCallback(TorchTrainingCallback):
...     def on_epoch_end(self, trainer, model, state, epoch):
...         print(f"Completed epoch {epoch} at step {state.step}")

on_training_start

on_training_start(trainer, model, state)
Source code in src/formed/integrations/torch/training/callbacks.py
86
87
88
89
90
91
92
def on_training_start(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> None:
    """Hook called once before the training loop begins. No-op by default."""
    pass

on_training_end

on_training_end(trainer, model, state)
Source code in src/formed/integrations/torch/training/callbacks.py
 94
 95
 96
 97
 98
 99
100
def on_training_end(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> TrainState:
    """Hook called once after training; may return a modified final state.

    The default implementation returns ``state`` unchanged.
    """
    return state

on_epoch_start

on_epoch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/torch/training/callbacks.py
102
103
104
105
106
107
108
109
def on_epoch_start(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    """Hook called at the start of each epoch. No-op by default."""
    pass

on_epoch_end

on_epoch_end(trainer, model, state, epoch)
Source code in src/formed/integrations/torch/training/callbacks.py
111
112
113
114
115
116
117
118
def on_epoch_end(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ItemT := ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    """Hook called at the end of each epoch. No-op by default."""
    pass

on_batch_start

on_batch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/torch/training/callbacks.py
120
121
122
123
124
125
126
127
def on_batch_start(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    """Hook called before each training batch. No-op by default."""
    pass

on_batch_end

on_batch_end(trainer, model, state, epoch, output)
Source code in src/formed/integrations/torch/training/callbacks.py
129
130
131
132
133
134
135
136
137
def on_batch_end(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
    output: ModelOutputT,
) -> None:
    """Hook called after each training batch with the model output. No-op by default."""
    pass

on_eval_start

on_eval_start(trainer, model, state)
Source code in src/formed/integrations/torch/training/callbacks.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def on_eval_start(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> IEvaluator[ModelInputT, ModelOutputT]:
    """Hook called before evaluation; returns the evaluator used to accumulate metrics.

    The default returns a do-nothing evaluator that produces an empty metrics dict.
    """
    class DummyMetric(IEvaluator):
        # Every method is a no-op: nothing is accumulated, no metrics are reported.
        def update(self, inputs, output, /) -> None:
            pass

        def compute(self) -> dict[str, float]:
            return {}

        def reset(self) -> None:
            pass

        def clone(self) -> Self:
            return self

    return DummyMetric()

on_eval_end

on_eval_end(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/torch/training/callbacks.py
160
161
162
163
164
165
166
167
168
def on_eval_end(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    metrics: Mapping[str, float],
    prefix: str = "",
) -> None:
    """Hook called after evaluation with the computed metrics. No-op by default."""
    pass

on_log

on_log(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/torch/training/callbacks.py
170
171
172
173
174
175
176
177
178
def on_log(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    metrics: Mapping[str, float],
    prefix: str = "",
) -> None:
    """Hook called when metrics are logged. No-op by default."""
    pass

EvaluationCallback

EvaluationCallback(evaluator)

Bases: TorchTrainingCallback, Generic[ModelInputT, ModelOutputT]

Callback for computing metrics using a custom evaluator.

This callback integrates a custom evaluator into the training loop, resetting it before each evaluation phase and returning it for metric accumulation.

PARAMETER DESCRIPTION
evaluator

Evaluator implementing the IEvaluator protocol.

TYPE: IEvaluator[ModelInputT, ModelOutputT]

Examples:

>>> from formed.integrations.ml.metrics import MulticlassAccuracy
>>>
>>> evaluator = MulticlassAccuracy()
>>> callback = EvaluationCallback(evaluator)
Source code in src/formed/integrations/torch/training/callbacks.py
200
201
def __init__(self, evaluator: IEvaluator[ModelInputT, ModelOutputT]) -> None:
    """Store the evaluator used to compute metrics during evaluation."""
    self._evaluator = evaluator

on_eval_start

on_eval_start(trainer, model, state)
Source code in src/formed/integrations/torch/training/callbacks.py
203
204
205
206
207
208
209
210
211
def on_eval_start(  # pyright: ignore[reportIncompatibleMethodOverride]
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> IEvaluator[ModelInputT, ModelOutputT]:
    """Return a fresh, reset clone of the configured evaluator for this evaluation pass."""
    # Clone so accumulated state from a previous pass never leaks into this one.
    evaluator = self._evaluator.clone()
    evaluator.reset()
    return evaluator

on_training_start

on_training_start(trainer, model, state)
Source code in src/formed/integrations/torch/training/callbacks.py
86
87
88
89
90
91
92
def on_training_start(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> None:
    """Hook called once before the training loop begins. No-op by default."""
    pass

on_training_end

on_training_end(trainer, model, state)
Source code in src/formed/integrations/torch/training/callbacks.py
 94
 95
 96
 97
 98
 99
100
def on_training_end(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> TrainState:
    """Hook called once after training; returns ``state`` unchanged by default."""
    return state

on_epoch_start

on_epoch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/torch/training/callbacks.py
102
103
104
105
106
107
108
109
def on_epoch_start(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    """Hook called at the start of each epoch. No-op by default."""
    pass

on_epoch_end

on_epoch_end(trainer, model, state, epoch)
Source code in src/formed/integrations/torch/training/callbacks.py
111
112
113
114
115
116
117
118
def on_epoch_end(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    """Hook called at the end of each epoch. No-op by default."""
    pass

on_batch_start

on_batch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/torch/training/callbacks.py
120
121
122
123
124
125
126
127
def on_batch_start(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    """Hook called before each training batch. No-op by default."""
    pass

on_batch_end

on_batch_end(trainer, model, state, epoch, output)
Source code in src/formed/integrations/torch/training/callbacks.py
129
130
131
132
133
134
135
136
137
def on_batch_end(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
    output: ModelOutputT,
) -> None:
    """Hook called after each training batch with the model output. No-op by default."""
    pass

on_eval_end

on_eval_end(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/torch/training/callbacks.py
160
161
162
163
164
165
166
167
168
def on_eval_end(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    metrics: Mapping[str, float],
    prefix: str = "",
) -> None:
    """Hook called after evaluation with the computed metrics. No-op by default."""
    pass

on_log

on_log(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/torch/training/callbacks.py
170
171
172
173
174
175
176
177
178
def on_log(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    metrics: Mapping[str, float],
    prefix: str = "",
) -> None:
    """Hook called when metrics are logged. No-op by default."""
    pass

EarlyStoppingCallback

EarlyStoppingCallback(patience=5, metric='-train/loss')

Bases: TorchTrainingCallback

Callback for early stopping based on metric improvements.

This callback monitors a specified metric and stops training if it doesn't improve for a given number of evaluations (patience). The best model is automatically saved and restored at the end of training.

PARAMETER DESCRIPTION
patience

Number of evaluations without improvement before stopping.

TYPE: int DEFAULT: 5

metric

Metric to monitor. Prefix with - to minimize (e.g., "-loss", lower is better), or + to maximize (e.g., "+accuracy", higher is better). Default is "-train/loss".

TYPE: str DEFAULT: '-train/loss'

Examples:

>>> # Stop if validation loss doesn't improve for 5 evaluations
>>> callback = EarlyStoppingCallback(patience=5, metric="-val/loss")
>>>
>>> # Stop if accuracy doesn't improve for 3 evaluations
>>> callback = EarlyStoppingCallback(patience=3, metric="+accuracy")
Note

The best model is saved to the step working directory and automatically restored when training ends early or completes.

Source code in src/formed/integrations/torch/training/callbacks.py
240
241
242
243
244
245
246
247
248
249
def __init__(
    self,
    patience: int = 5,
    metric: str = "-train/loss",
) -> None:
    """Initialize early-stopping bookkeeping.

    Args:
        patience: Number of evaluations without improvement before stopping.
        metric: Metric to monitor; a leading "-" means lower values are
            better, a leading "+" means higher values are better.

    """
    self._patience = patience
    # Strip the optional sign prefix to get the bare metric name.
    self._metric = metric.lstrip("-+")
    # -1 flips the comparison so "-metric" treats smaller raw values as better.
    self._direction = -1 if metric.startswith("-") else 1
    self._best_metric = float("-inf")
    self._counter = 0

on_training_start

on_training_start(trainer, model, state)
Source code in src/formed/integrations/torch/training/callbacks.py
259
260
261
262
263
264
265
266
def on_training_start(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> None:
    self._best_metric = -float("inf")
    self._counter = 0

on_eval_end

on_eval_end(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/torch/training/callbacks.py
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
def on_eval_end(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    metrics: Mapping[str, float],
    prefix: str = "",
) -> None:
    """Track the monitored metric, checkpoint on improvement, and stop after ``patience`` stalls.

    Only the main process evaluates the metric; the stop decision is then
    broadcast so all DDP ranks stop together.

    Raises:
        StopEarly: When the monitored metric has not improved for
            ``self._patience`` consecutive evaluations.

    """
    import torch
    import torch.distributed as dist

    logger = use_step_logger(__name__)

    # For DDP: only rank 0 makes the early stopping decision
    should_stop = False

    if trainer.distributor.is_main_process:
        if prefix:
            metrics = {f"{prefix}{key}": value for key, value in metrics.items()}
        try:
            # ``_direction`` maps the raw value so that "bigger is better".
            metric = self._direction * metrics[self._metric]
        except KeyError:
            # Monitored metric absent from this evaluation: skip silently.
            return
        if metric > self._best_metric:
            self._best_metric = metric
            self._counter = 0
            # Save state_dict for serialization efficiency
            # NOTE(review): ``_get_checkpoint_path`` is defined elsewhere in the class.
            torch.save(state.state_dict(), self._get_checkpoint_path())
            logger.info(f"New best model saved with {self._metric}={self._best_metric:.4f}")
        else:
            self._counter += 1
            if self._counter >= self._patience:
                should_stop = True

    # For DDP: broadcast the early stopping decision from rank 0 to all processes
    if trainer.distributor.world_size > 1 and dist.is_initialized():
        # Create a tensor for broadcasting
        # Get device from model's first parameter
        device = next(state.model.parameters()).device if list(state.model.parameters()) else torch.device("cpu")
        should_stop_tensor = torch.tensor(1 if should_stop else 0, dtype=torch.int, device=device)
        dist.broadcast(should_stop_tensor, src=0)
        should_stop = bool(should_stop_tensor.item())

    # Synchronize all processes before potentially raising StopEarly
    trainer.distributor.barrier()

    if should_stop:
        raise StopEarly()

on_training_end

on_training_end(trainer, model, state)
Source code in src/formed/integrations/torch/training/callbacks.py
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
def on_training_end(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> TrainState:
    """Restore the best checkpoint (if one was saved) and sync it across DDP ranks.

    Returns:
        The training state, with the best saved weights loaded when available.

    """
    import torch
    import torch.distributed as dist

    logger = use_step_logger(__name__)

    # Synchronize before loading best model
    trainer.distributor.barrier()

    # Load best model if it exists
    if (checkpoint_path := self._get_checkpoint_path()).exists():
        if trainer.distributor.is_main_process:
            logger.info("Loading best state from early stopping checkpoint.")
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            state.load_state_dict(state_dict)

        # For DDP: broadcast the model state from rank 0 to all other processes
        if trainer.distributor.world_size > 1 and dist.is_initialized():
            # Broadcast model parameters
            for param in state.model.parameters():
                dist.broadcast(param.data, src=0)
            # Broadcast optimizer state
            # Note: optimizer state broadcasting is complex, so we skip it for now
            # Users can re-initialize optimizer if needed after loading best model
            if trainer.distributor.is_main_process:
                logger.info("Broadcasted best model to all processes.")

    # Synchronize after loading so all processes have the best model
    trainer.distributor.barrier()
    return state

on_epoch_start

on_epoch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/torch/training/callbacks.py
102
103
104
105
106
107
108
109
def on_epoch_start(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    """Hook called at the start of each epoch. No-op by default."""
    pass

on_epoch_end

on_epoch_end(trainer, model, state, epoch)
Source code in src/formed/integrations/torch/training/callbacks.py
111
112
113
114
115
116
117
118
def on_epoch_end(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    """Hook called at the end of each epoch. No-op by default."""
    pass

on_batch_start

on_batch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/torch/training/callbacks.py
120
121
122
123
124
125
126
127
def on_batch_start(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    """Hook called before each training batch. No-op by default."""
    pass

on_batch_end

on_batch_end(trainer, model, state, epoch, output)
Source code in src/formed/integrations/torch/training/callbacks.py
129
130
131
132
133
134
135
136
137
def on_batch_end(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
    output: ModelOutputT,
) -> None:
    """Hook called after each training batch with the model output. No-op by default."""
    pass

on_eval_start

on_eval_start(trainer, model, state)
Source code in src/formed/integrations/torch/training/callbacks.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def on_eval_start(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> IEvaluator[ModelInputT, ModelOutputT]:
    """Hook called before evaluation; returns the evaluator used to accumulate metrics.

    The default returns a do-nothing evaluator that produces an empty metrics dict.
    """
    class DummyMetric(IEvaluator):
        # Every method is a no-op: nothing is accumulated, no metrics are reported.
        def update(self, inputs, output, /) -> None:
            pass

        def compute(self) -> dict[str, float]:
            return {}

        def reset(self) -> None:
            pass

        def clone(self) -> Self:
            return self

    return DummyMetric()

on_log

on_log(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/torch/training/callbacks.py
170
171
172
173
174
175
176
177
178
def on_log(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    metrics: Mapping[str, float],
    prefix: str = "",
) -> None:
    """Hook called when metrics are logged.

    No-op by default; subclasses may override to forward metrics to an
    external sink (see ``MlflowCallback.on_log``).

    Args:
        trainer: The running trainer.
        model: The model being trained.
        state: Current training state.
        metrics: Mapping of metric names to values.
        prefix: Prefix to prepend to metric names when forwarding.
    """
    pass

MlflowCallback

MlflowCallback()

Bases: TorchTrainingCallback

Callback for logging metrics to MLflow.

This callback automatically logs training and validation metrics to MLflow when used within a workflow step that has MLflow tracking enabled.

Examples:

>>> from formed.integrations.torch import TorchTrainer, MlflowCallback
>>>
>>> trainer = TorchTrainer(
...     train_dataloader=train_loader,
...     val_dataloader=val_loader,
...     callbacks=[MlflowCallback()]
... )
Note

Requires the formed mlflow integration and must be used within a workflow step with MLflow tracking configured.

Source code in src/formed/integrations/torch/training/callbacks.py
376
377
378
379
def __init__(self) -> None:
    """Initialize the callback with no MLflow logger attached yet.

    The logger is resolved later in ``on_training_start``.
    """
    # Imported inside the method — presumably to avoid a hard mlflow
    # dependency at module import time; confirm against package layout.
    from formed.integrations.mlflow.workflow import MlflowLogger

    # Resolved lazily; stays None on worker processes or when no MLflow
    # tracking is configured.
    self._mlflow_logger: Optional[MlflowLogger] = None

on_training_start

on_training_start(trainer, model, state)
Source code in src/formed/integrations/torch/training/callbacks.py
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
def on_training_start(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> None:
    """Resolve the MLflow logger from the active workflow step.

    Runs only on the main process; warns when no logger is available.
    """
    from formed.integrations.mlflow.workflow import use_mlflow_logger
    from formed.workflow import use_step_logger

    # MLflow logging is restricted to the main process in distributed runs.
    if not trainer.distributor.is_main_process:
        return

    step_logger = use_step_logger(__name__)

    mlflow_logger = use_mlflow_logger()
    self._mlflow_logger = mlflow_logger
    if mlflow_logger is None:
        step_logger.warning("MlflowLogger not found. Skipping logging.")

on_log

on_log(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/torch/training/callbacks.py
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
def on_log(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    metrics: Mapping[str, float],
    prefix: str = "",
) -> None:
    """Forward metrics (plus learning rate and gradient norm) to MLflow."""
    # MLflow logging happens only on the main process of a distributed run.
    if not trainer.distributor.is_main_process:
        return

    logger = self._mlflow_logger
    if logger is None:
        # No logger was resolved at training start; nothing to do.
        return

    step = int(state.step)

    # Task metrics, with the caller-supplied prefix applied to each name.
    for name, value in metrics.items():
        logger.log_metric(prefix + name, value, step=step)

    # Current optimizer learning rate, when it can be determined.
    lr = state.get_learning_rate()
    if lr is not None:
        logger.log_metric("learning_rate", lr, step=step)

    # Global gradient norm, when gradients are present.
    norm = state.get_gradient_norm()
    if norm is not None:
        logger.log_metric("gradient_norm", norm, step=step)

on_epoch_end

on_epoch_end(trainer, model, state, epoch)
Source code in src/formed/integrations/torch/training/callbacks.py
428
429
430
431
432
433
434
435
436
437
438
439
440
def on_epoch_end(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    """Record the finished epoch number as an MLflow metric."""
    # Skip on worker processes and when no MLflow logger was resolved.
    if not trainer.distributor.is_main_process or self._mlflow_logger is None:
        return
    self._mlflow_logger.log_metric("epoch", epoch, step=int(state.step))

on_training_end

on_training_end(trainer, model, state)
Source code in src/formed/integrations/torch/training/callbacks.py
 94
 95
 96
 97
 98
 99
100
def on_training_end(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> TrainState:
    """Hook called when training finishes.

    Returning a state allows a callback to transform the final result;
    the default implementation returns ``state`` unchanged.
    """
    return state

on_epoch_start

on_epoch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/torch/training/callbacks.py
102
103
104
105
106
107
108
109
def on_epoch_start(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    """Hook called when a training epoch begins.

    No-op by default; subclasses may override.

    Args:
        trainer: The running trainer.
        model: The model being trained.
        state: Current training state.
        epoch: Index of the epoch that is starting.
    """
    pass

on_batch_start

on_batch_start(trainer, model, state, epoch)
Source code in src/formed/integrations/torch/training/callbacks.py
120
121
122
123
124
125
126
127
def on_batch_start(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
) -> None:
    """Hook called before each training batch is processed.

    No-op by default; subclasses may override. Note the hook receives the
    epoch index, not the batch itself.
    """
    pass

on_batch_end

on_batch_end(trainer, model, state, epoch, output)
Source code in src/formed/integrations/torch/training/callbacks.py
129
130
131
132
133
134
135
136
137
def on_batch_end(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    epoch: int,
    output: ModelOutputT,
) -> None:
    """Hook called after each training batch is processed.

    No-op by default; subclasses may override.

    Args:
        trainer: The running trainer.
        model: The model being trained.
        state: Current training state.
        epoch: Index of the current epoch.
        output: Model output produced for the batch.
    """
    pass

on_eval_start

on_eval_start(trainer, model, state)
Source code in src/formed/integrations/torch/training/callbacks.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def on_eval_start(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
) -> IEvaluator[ModelInputT, ModelOutputT]:
    """Provide the evaluator used for an evaluation pass.

    The default implementation returns a fresh placeholder evaluator that
    ignores updates, computes an empty metrics mapping, and clones to
    itself.
    """

    class DummyMetric(IEvaluator):
        # Placeholder evaluator: accepts everything, reports nothing.
        def update(self, inputs, output, /) -> None:
            pass

        def compute(self) -> dict[str, float]:
            return {}

        def reset(self) -> None:
            pass

        def clone(self) -> Self:
            return self

    return DummyMetric()

on_eval_end

on_eval_end(trainer, model, state, metrics, prefix='')
Source code in src/formed/integrations/torch/training/callbacks.py
160
161
162
163
164
165
166
167
168
def on_eval_end(
    self,
    trainer: "TorchTrainer[ItemT, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    state: TrainState,
    metrics: Mapping[str, float],
    prefix: str = "",
) -> None:
    """Hook called when an evaluation pass completes.

    No-op by default; subclasses may override.

    Args:
        trainer: The running trainer.
        model: The model being trained.
        state: Current training state.
        metrics: Computed evaluation metrics.
        prefix: Prefix used when forwarding metric names.
    """
    pass

formed.integrations.torch.training.engine

Training engine abstractions for PyTorch models.

This module provides the training engine abstraction that defines how models are trained and evaluated. Engines handle loss computation, gradient calculation, and parameter updates.

Key Components
  • TorchTrainingEngine: Abstract base class for training engines
  • DefaultTorchTrainingEngine: Default implementation with automatic differentiation
Features
  • Customizable loss functions
  • Automatic gradient computation using PyTorch autograd
  • State creation and management
  • Separate train and eval steps
  • Compatible with TorchTrainer and distributors

Examples:

>>> from formed.integrations.torch import DefaultTorchTrainingEngine
>>>
>>> # Create engine with custom loss accessor
>>> engine = DefaultTorchTrainingEngine(loss="total_loss")
>>>
>>> # Or with custom loss function
>>> def custom_loss(output):
...     return output.loss + 0.1 * output.regularization
>>> engine = DefaultTorchTrainingEngine(loss=custom_loss)

TorchTrainingEngine

Bases: ABC, Registrable, Generic[ModelInputT, ModelOutputT, ModelParamsT]

Abstract base class for PyTorch training engines.

A training engine defines how models are trained by implementing state creation, training steps, and evaluation steps. This allows for custom training loops and loss computations.

CLASS TYPE PARAMETER DESCRIPTION
ModelInputT

Type of model input.

ModelOutputT

Type of model output.

ModelParamsT

Type of additional parameters.

create_state abstractmethod

create_state(trainer, model)

Create initial training state from model and trainer.

PARAMETER DESCRIPTION
trainer

Trainer instance.

TYPE: TorchTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]

model

Model to train.

TYPE: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT]

RETURNS DESCRIPTION
TrainState

Initial training state.

Source code in src/formed/integrations/torch/training/engine.py
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
@abc.abstractmethod
def create_state(
    self,
    trainer: "TorchTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
) -> TrainState:
    """Create initial training state from model and trainer.

    Subclasses must override this; a typical implementation constructs
    the optimizer (and any scheduler or grad scaler) and wraps them in a
    `TrainState`, as `DefaultTorchTrainingEngine` does.

    Args:
        trainer: Trainer instance.
        model: Model to train.

    Returns:
        Initial training state.

    """
    raise NotImplementedError

train_step abstractmethod

train_step(inputs, state, trainer)

Execute a single training step.

PARAMETER DESCRIPTION
inputs

Batch of training inputs.

TYPE: ModelInputT

state

Current training state (model and optimizer are updated in-place).

TYPE: TrainState

trainer

Trainer instance.

TYPE: TorchTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]

RETURNS DESCRIPTION
ModelOutputT

Model output.

Note

This method updates the state in-place for efficiency. The step counter is incremented automatically.

Source code in src/formed/integrations/torch/training/engine.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
@abc.abstractmethod
def train_step(
    self,
    inputs: ModelInputT,
    state: TrainState,
    trainer: "TorchTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]",
) -> ModelOutputT:
    """Execute a single training step.

    Args:
        inputs: Batch of training inputs.
        state: Current training state (model and optimizer are updated in-place).
        trainer: Trainer instance.

    Returns:
        Model output.

    Note:
        This method updates the state in-place for efficiency.
        The step counter is incremented automatically.
        Implementations should increment ``state.step`` once per call
        (see `DefaultTorchTrainingEngine.train_step`).

    """
    raise NotImplementedError

eval_step abstractmethod

eval_step(inputs, state, trainer)

Execute a single evaluation step.

PARAMETER DESCRIPTION
inputs

Batch of evaluation inputs.

TYPE: ModelInputT

state

Current training state.

TYPE: TrainState

trainer

Trainer instance.

TYPE: TorchTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]

RETURNS DESCRIPTION
ModelOutputT

Model output.

Source code in src/formed/integrations/torch/training/engine.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
@abc.abstractmethod
def eval_step(
    self,
    inputs: ModelInputT,
    state: TrainState,
    trainer: "TorchTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]",
) -> ModelOutputT:
    """Execute a single evaluation step.

    Implementations are expected to leave model parameters unchanged
    (the default engine runs the forward pass under ``torch.no_grad()``).

    Args:
        inputs: Batch of evaluation inputs.
        state: Current training state.
        trainer: Trainer instance.

    Returns:
        Model output.

    """
    raise NotImplementedError

DefaultTorchTrainingEngine

DefaultTorchTrainingEngine(
    optimizer=None,
    lr_scheduler=None,
    loss="loss",
    gradient_accumulation_steps=1,
    max_grad_norm=None,
    params=None,
    dtype=None,
    grad_scaler=None,
)

Bases: TorchTrainingEngine[ModelInputT, ModelOutputT, ModelParamsT]

Default training engine using automatic differentiation.

This engine computes gradients using PyTorch's autograd and updates parameters using the provided optimizer. Loss is extracted from model output either by attribute name or custom function.

PARAMETER DESCRIPTION
optimizer

Optimizer factory or instance. Can be a Lazy object, callable that takes model parameters, or an optimizer instance.

TYPE: Optional[OptimizerFactory] DEFAULT: None

lr_scheduler

Optional learning rate scheduler factory or instance. Can be a Lazy object, callable that takes optimizer, a sequence of schedulers (will be chained), or a scheduler instance.

TYPE: Union[LRSchedulerFactory, Sequence[LRSchedulerFactory], None] DEFAULT: None

loss

Loss accessor - either attribute name (e.g., "loss") or callable that extracts loss from model output.

TYPE: Union[str, Callable[[ModelOutputT], Tensor]] DEFAULT: 'loss'

gradient_accumulation_steps

Number of steps to accumulate gradients before performing an optimizer step.

TYPE: int DEFAULT: 1

max_grad_norm

Maximum gradient norm for clipping. If None, no clipping is applied.

TYPE: Optional[float] DEFAULT: None

params

Optional additional parameters to pass to the model during training.

TYPE: ModelParamsT | None DEFAULT: None

dtype

Data type for mixed precision training ("float32", "float16", "bfloat16").

TYPE: Literal['float32', 'float16', 'bfloat16'] | None DEFAULT: None

grad_scaler

Gradient scaler for mixed precision training.

TYPE: Lazy[GradScaler | IGradScaler] | IGradScaler | None DEFAULT: None

Examples:

>>> # Basic usage with optimizer
>>> engine = DefaultTorchTrainingEngine(
...     optimizer=torch.optim.Adam,
...     loss="loss"
... )
>>>
>>> # With learning rate scheduler and gradient clipping
>>> engine = DefaultTorchTrainingEngine(
...     optimizer=Lazy(cls=torch.optim.Adam, config={"lr": 1e-3}),
...     lr_scheduler=Lazy(cls=torch.optim.lr_scheduler.CosineAnnealingLR, config={"T_max": 100}),
...     max_grad_norm=1.0,
...     loss=lambda output: output.loss + 0.01 * output.regularization
... )
>>>
>>> # With mixed precision training
>>> engine = DefaultTorchTrainingEngine(
...     optimizer=torch.optim.AdamW,
...     dtype="bfloat16",
...     grad_scaler=Lazy(cls=torch.amp.GradScaler),
...     max_grad_norm=1.0
... )
Source code in src/formed/integrations/torch/training/engine.py
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
def __init__(
    self,
    optimizer: Optional[OptimizerFactory] = None,
    lr_scheduler: Union[LRSchedulerFactory, Sequence[LRSchedulerFactory], None] = None,
    loss: Union[str, Callable[[ModelOutputT], torch.Tensor]] = "loss",
    gradient_accumulation_steps: int = 1,
    max_grad_norm: Optional[float] = None,
    params: ModelParamsT | None = None,
    dtype: Literal["float32", "float16", "bfloat16"] | None = None,
    grad_scaler: Lazy[torch.amp.grad_scaler.GradScaler | IGradScaler] | IGradScaler | None = None,
) -> None:
    """Configure the default training engine.

    Args:
        optimizer: Optimizer factory or instance; defaults to the package default.
        lr_scheduler: Scheduler factory/instance or sequence thereof; optional.
        loss: Attribute name or callable used to extract the loss from model output.
        gradient_accumulation_steps: Micro-batches to accumulate before each optimizer step.
        max_grad_norm: Gradient-clipping threshold; None disables clipping.
        params: Extra parameters forwarded to the model during training.
        dtype: Autocast dtype for mixed precision, or None to disable.
        grad_scaler: Gradient scaler (or factory) for mixed precision.

    Raises:
        ValueError: If ``dtype`` is unsupported or ``gradient_accumulation_steps`` is less than 1.

    """
    # Validate with real exceptions: ``assert`` is stripped under ``python -O``
    # and would silently accept bad configuration.
    if dtype not in (None, "float32", "float16", "bfloat16"):
        raise ValueError("dtype must be one of None, 'float32', 'float16', or 'bfloat16'")
    # train_step takes ``state.step % gradient_accumulation_steps``; a value
    # below 1 would fail there with a much less helpful error.
    if gradient_accumulation_steps < 1:
        raise ValueError("gradient_accumulation_steps must be a positive integer")

    super().__init__()
    self._optimizer_factory = optimizer or get_default_optimizer_factory()
    self._lr_scheduler_factory = lr_scheduler or get_default_lr_scheduler_factory()
    # A string loss is resolved by attribute/key lookup on the model output.
    self._loss = partial(xgetattr, name=loss) if isinstance(loss, str) else loss
    self._gradient_accumulation_steps = gradient_accumulation_steps
    self._max_grad_norm = max_grad_norm
    self._params = params
    # Map the dtype name to the actual torch dtype object once, up front.
    self._dtype = getattr(torch, dtype) if dtype is not None else None
    self._grad_scaler = grad_scaler

create_state

create_state(trainer, model)
Source code in src/formed/integrations/torch/training/engine.py
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
def create_state(
    self,
    trainer: "TorchTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]",
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
) -> TrainState:
    """Build the initial `TrainState` for ``model``.

    Constructs the optimizer, the optional LR scheduler(s), and the
    optional grad scaler from the factories supplied at engine
    construction. Each factory may be a `Lazy` object, a callable, or a
    ready instance; sequences of schedulers are chained.

    Args:
        trainer: Trainer instance (unused here but part of the engine contract).
        model: Model whose parameters the optimizer will manage.

    Returns:
        Fresh training state with ``step`` set to 0.
    """
    # Construct optimizer
    optimizer: IOptimizer
    if isinstance(self._optimizer_factory, Lazy):
        optimizer = self._optimizer_factory.construct(params=model.parameters())
    elif callable(self._optimizer_factory):
        optimizer = self._optimizer_factory(model.parameters())
    else:
        # Already a constructed optimizer instance.
        optimizer = self._optimizer_factory

    # Construct lr_scheduler
    lr_scheduler: ILRScheduler | None = None
    if self._lr_scheduler_factory is not None:

        def _construct_lr_scheduler(factory: LRSchedulerFactory) -> ILRScheduler:
            # Same three-way convention as the optimizer factory above.
            if isinstance(factory, Lazy):
                return factory.construct(optimizer=optimizer)
            elif callable(factory):
                return factory(optimizer)
            return factory

        if isinstance(self._lr_scheduler_factory, Sequence):
            # Multiple schedulers are combined into one chained scheduler.
            lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler(
                [
                    cast(torch.optim.lr_scheduler.LRScheduler, _construct_lr_scheduler(scheduler))
                    for scheduler in self._lr_scheduler_factory
                ],
                optimizer=cast(torch.optim.Optimizer, optimizer),
            )
        else:
            lr_scheduler = _construct_lr_scheduler(self._lr_scheduler_factory)

    # Initialize grad_scaler only if requested and on appropriate device
    grad_scaler: IGradScaler | None = None
    if isinstance(self._grad_scaler, IGradScaler):
        grad_scaler = self._grad_scaler
    elif isinstance(self._grad_scaler, Lazy):
        if device := get_device():
            grad_scaler = self._grad_scaler.construct(device=device.type)
        else:
            # Mixed precision was requested, but without a known device the
            # scaler cannot be constructed; proceed unscaled.
            warnings.warn("GradScaler requested but device is not set. GradScaler will not be used.")

    return TrainState(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        step=0,
        grad_scaler=grad_scaler,
    )

train_step

train_step(inputs, state, trainer)
Source code in src/formed/integrations/torch/training/engine.py
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
def train_step(
    self,
    inputs: ModelInputT,
    state: TrainState,
    trainer: "TorchTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]",
) -> ModelOutputT:
    """Run one training micro-batch, updating parameters on accumulation boundaries.

    Performs a forward pass (under autocast when a dtype is configured),
    extracts the loss, backpropagates (through the grad scaler when one is
    attached), and — once every ``gradient_accumulation_steps`` micro-batches —
    clips gradients (if configured) and steps the optimizer and scheduler.
    ``state`` is mutated in place; ``state.step`` counts micro-batches, not
    optimizer steps.

    Args:
        inputs: Batch of training inputs.
        state: Current training state (mutated in place).
        trainer: Trainer instance (unused).

    Returns:
        Model output of the forward pass.

    Raises:
        ValueError: If the loss cannot be extracted from the output, or is None.
    """
    del trainer

    # Set model to training mode
    state.model.train()

    # Zero gradients at the start of accumulation cycle
    if state.step % self._gradient_accumulation_steps == 0:
        state.optimizer.zero_grad()

    with ExitStack() as stack:
        # Autocast is only entered when both a device and a dtype are known.
        if (device := get_device()) is not None and self._dtype is not None:
            stack.enter_context(torch.autocast(device_type=device.type, dtype=self._dtype))

        output = state.model(inputs, params=self._params)

        try:
            loss = self._loss(output)
        except (KeyError, AttributeError) as e:
            raise ValueError(
                f"Failed to extract loss from model output. "
                f"Error: {e}. "
                f"Output type: {type(output).__name__}. "
                "Please ensure your model's forward() method returns output with a 'loss' attribute or key."
            ) from e

    if loss is None:
        raise ValueError(
            "Model output loss is None. "
            "This typically happens when labels are not provided during training. "
            "Please ensure your training data includes labels."
        )

    # Scale loss for gradient accumulation
    loss = loss / self._gradient_accumulation_steps

    # Backward pass with or without gradient scaling
    if state.grad_scaler is not None:
        state.grad_scaler.scale(loss).backward()
    else:
        loss.backward()

    # Update optimizer and scheduler when accumulation is complete
    if (state.step + 1) % self._gradient_accumulation_steps == 0:
        # Clip gradients if max_grad_norm is specified
        if self._max_grad_norm is not None:
            if state.grad_scaler is not None:
                if isinstance(state.optimizer, torch.optim.Optimizer):
                    # Unscale gradients before clipping when using grad_scaler
                    state.grad_scaler.unscale_(state.optimizer)
                else:
                    warnings.warn(
                        "Cannot unscale gradients for gradient clipping because "
                        "the optimizer is not a torch.optim.Optimizer instance."
                    )
            torch.nn.utils.clip_grad_norm_(state.model.parameters(), self._max_grad_norm)

        if state.grad_scaler is not None:
            # GradScaler.step expects torch.optim.Optimizer
            state.grad_scaler.step(cast(torch.optim.Optimizer, state.optimizer))
            state.grad_scaler.update()
        else:
            state.optimizer.step()
        if state.lr_scheduler is not None:
            state.lr_scheduler.step()

    # Increment step counter (counts micro-batches)
    state.step += 1

    return output

eval_step

eval_step(inputs, state, trainer)
Source code in src/formed/integrations/torch/training/engine.py
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
def eval_step(
    self,
    inputs: ModelInputT,
    state: TrainState,
    trainer: "TorchTrainer[Any, ModelInputT, ModelOutputT, ModelParamsT]",
) -> ModelOutputT:
    """Run one forward pass in evaluation mode without tracking gradients."""
    del trainer  # unused; kept for interface parity with train_step

    model = state.model
    # Switch off training-only behavior (e.g. dropout).
    model.eval()

    # No autograd graph is needed during evaluation.
    with torch.no_grad():
        return model(inputs)

get_default_optimizer_factory

get_default_optimizer_factory()

Get a default optimizer factory (Adam with lr=1e-3).

Source code in src/formed/integrations/torch/training/engine.py
138
139
140
def get_default_optimizer_factory() -> OptimizerFactory:
    """Return the fallback optimizer factory: lazily-constructed Adam with lr=1e-3."""
    return Lazy(cls=torch.optim.Adam, config={"lr": 1e-3})

get_default_lr_scheduler_factory

get_default_lr_scheduler_factory()

Get a default learning rate scheduler factory (None).

Source code in src/formed/integrations/torch/training/engine.py
143
144
145
def get_default_lr_scheduler_factory() -> Optional[LRSchedulerFactory]:
    """Get a default learning rate scheduler factory (None).

    No scheduler is used unless one is explicitly configured on the engine.
    """
    return None

formed.integrations.torch.training.exceptions

StopEarly

Bases: Exception

Raised to stop training early.

formed.integrations.torch.training.state

Training state management for PyTorch models.

This module provides a training state class that encapsulates model parameters, optimizer state, and training progress for PyTorch models.

TrainState

TrainState(
    model,
    optimizer,
    step=0,
    lr_scheduler=None,
    grad_scaler=None,
)

Training state for PyTorch models.

This class encapsulates the training state including model, optimizer, learning rate scheduler, and training progress counters. Unlike the Flax version, this directly holds references to the model and optimizer for efficiency.

ATTRIBUTE DESCRIPTION
model

The PyTorch model being trained.

optimizer

The optimizer for training.

lr_scheduler

Optional learning rate scheduler.

step

Training step counter.

grad_scaler

Optional gradient scaler for mixed precision training.

Examples:

>>> # Create state from model and optimizer
>>> state = TrainState(
...     model=model,
...     optimizer=optimizer,
...     step=0
... )
>>>
>>> # Access model and optimizer directly
>>> state.model.train()
>>> state.optimizer.zero_grad()
Source code in src/formed/integrations/torch/training/state.py
45
46
47
48
49
50
51
52
53
54
55
56
57
def __init__(
    self,
    model: torch.nn.Module,
    optimizer: "IOptimizer",
    step: int = 0,
    lr_scheduler: Optional["ILRScheduler"] = None,
    grad_scaler: Optional["IGradScaler"] = None,
) -> None:
    """Bundle model, optimizer, and training-progress bookkeeping.

    Direct references are stored; no copies are made, so mutating the
    model or optimizer elsewhere is reflected in this state.

    Args:
        model: The PyTorch model being trained.
        optimizer: Optimizer managing the model's parameters.
        step: Initial step counter (0 for a fresh run).
        lr_scheduler: Optional learning rate scheduler.
        grad_scaler: Optional gradient scaler for mixed precision.
    """
    self.model = model
    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
    self.step = step
    self.grad_scaler = grad_scaler

model instance-attribute

model = model

optimizer instance-attribute

optimizer = optimizer

lr_scheduler instance-attribute

lr_scheduler = lr_scheduler

step instance-attribute

step = step

grad_scaler instance-attribute

grad_scaler = grad_scaler

state_dict

state_dict()

Get state dictionary for serialization.

RETURNS DESCRIPTION
dict[str, Any]

Dictionary containing model state, optimizer state, lr_scheduler state (if present), grad_scaler state (if present), and step.

Source code in src/formed/integrations/torch/training/state.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def state_dict(self) -> dict[str, Any]:
    """Serialize the training state into a plain dictionary.

    Always contains the model state, optimizer state, and step counter;
    scheduler and grad-scaler states are included only when those
    components are attached.

    Returns:
        Dictionary suitable for `load_state_dict`.
    """
    snapshot: dict[str, Any] = {
        "model_state": self.model.state_dict(),
        "optimizer_state": self.optimizer.state_dict(),
        "step": self.step,
    }
    # Optional components contribute an entry only when present.
    for key, component in (
        ("lr_scheduler_state", self.lr_scheduler),
        ("grad_scaler_state", self.grad_scaler),
    ):
        if component is not None:
            snapshot[key] = component.state_dict()
    return snapshot

load_state_dict

load_state_dict(state_dict)

Load state from dictionary.

PARAMETER DESCRIPTION
state_dict

Dictionary containing model state, optimizer state, lr_scheduler state (optional), grad_scaler state (optional), and step.

TYPE: dict[str, Any]

Source code in src/formed/integrations/torch/training/state.py
77
78
79
80
81
82
83
84
85
86
87
88
89
90
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Restore model, optimizer, step, and optional components from a dict.

    Args:
        state_dict: Dictionary produced by `state_dict`; the scheduler and
            grad-scaler entries are applied only when both the entry and the
            corresponding component exist.
    """
    self.model.load_state_dict(state_dict["model_state"])
    self.optimizer.load_state_dict(state_dict["optimizer_state"])
    self.step = state_dict["step"]
    # Optional components are restored opportunistically.
    for key, component in (
        ("lr_scheduler_state", self.lr_scheduler),
        ("grad_scaler_state", self.grad_scaler),
    ):
        if component is not None and key in state_dict:
            component.load_state_dict(state_dict[key])

get_learning_rate

get_learning_rate()

Get current learning rate from optimizer.

RETURNS DESCRIPTION
Optional[float]

Current learning rate from the first parameter group, or None if unavailable.

Examples:

>>> lr = state.get_learning_rate()
>>> if lr is not None:
...     print(f"Current learning rate: {lr}")
Source code in src/formed/integrations/torch/training/state.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
def get_learning_rate(self) -> Optional[float]:
    """Return the learning rate of the first parameter group, or None.

    Only genuine ``torch.optim.Optimizer`` instances expose
    ``param_groups``; for any other optimizer implementation (or an
    optimizer with no parameter groups) this returns None.

    Examples:
        >>> lr = state.get_learning_rate()
        >>> if lr is not None:
        ...     print(f"Current learning rate: {lr}")

    """
    import torch.optim

    optimizer = self.optimizer
    if not isinstance(optimizer, torch.optim.Optimizer):
        return None
    groups = optimizer.param_groups
    return groups[0]["lr"] if groups else None

get_gradient_norm

get_gradient_norm()

Compute L2 norm of all gradients.

RETURNS DESCRIPTION
Optional[float]

L2 norm of all parameter gradients, or None if no gradients are available.

Examples:

>>> grad_norm = state.get_gradient_norm()
>>> if grad_norm is not None:
...     print(f"Gradient norm: {grad_norm:.4f}")
Note

This method computes the gradient norm on-demand. It should be called after backward() but before optimizer.step() or zero_grad() to get meaningful results.

Source code in src/formed/integrations/torch/training/state.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def get_gradient_norm(self) -> Optional[float]:
    """Compute the global L2 norm over all parameter gradients.

    Returns:
        L2 norm of all parameter gradients, or None when no parameter
        currently holds a gradient.

    Examples:
        >>> grad_norm = state.get_gradient_norm()
        >>> if grad_norm is not None:
        ...     print(f"Gradient norm: {grad_norm:.4f}")

    Note:
        Computed on demand: call after `backward()` but before
        `optimizer.step()` or `zero_grad()` for meaningful results.

    """
    # Collect the squared per-parameter norms; parameters without
    # gradients are skipped entirely.
    squared_norms = [
        param.grad.data.norm(2).item() ** 2
        for param in self.model.parameters()
        if param.grad is not None
    ]
    if not squared_norms:
        return None
    return sum(squared_norms) ** 0.5

formed.integrations.torch.training.trainer

High-level trainer for PyTorch models.

This module provides the TorchTrainer class, which orchestrates the complete training process for PyTorch models including data loading, optimization, evaluation, callbacks, and distributed training.

Key Features
  • Flexible training loop with epoch and step-based logging/evaluation
  • Support for callbacks at various training stages
  • Distributed training via data parallelism
  • Integration with PyTorch optimizers
  • Rich progress bars with training metrics
  • Early stopping and checkpointing
  • MLflow integration

Examples:

>>> from formed.integrations.torch import (
...     TorchTrainer,
...     EvaluationCallback,
...     EarlyStoppingCallback
... )
>>> from formed.integrations.ml import DataLoader, BasicBatchSampler
>>> import torch.optim as optim
>>>
>>> # Setup data loaders and engine
>>> train_dataloader = DataLoader(
...     sampler=BasicBatchSampler(batch_size=32, shuffle=True),
...     collator=datamodule.batch
... )
>>> engine = DefaultTorchTrainingEngine(
...     optimizer=optim.Adam,
...     lr_scheduler=optim.lr_scheduler.StepLR
... )
>>>
>>> # Create trainer
>>> trainer = TorchTrainer(
...     train_dataloader=train_dataloader,
...     val_dataloader=val_dataloader,
...     engine=engine,
...     max_epochs=10,
...     callbacks=[
...         EvaluationCallback(my_evaluator),
...         EarlyStoppingCallback(patience=3)
...     ]
... )
>>>
>>> # Train model
>>> state = trainer.train(model, train_dataset, val_dataset)

TorchTrainer

TorchTrainer(
    *,
    train_dataloader,
    val_dataloader=None,
    engine=None,
    callbacks=(),
    distributor=None,
    max_epochs=None,
    eval_strategy="epoch",
    eval_interval=1,
    logging_strategy="epoch",
    logging_interval=1,
    logging_first_step=True,
    train_prefix="train/",
    val_prefix="val/",
)

Bases: Generic[ItemT, ModelInputT, ModelOutputT, ModelParamsT]

High-level trainer for PyTorch models.

TorchTrainer provides a complete training loop with support for distributed training, callbacks, evaluation, and metric logging. It handles the coordination of data loading, model training, evaluation, and callback execution.

CLASS TYPE PARAMETER DESCRIPTION
ItemT

Type of raw dataset items.

ModelInputT

Type of batched model inputs.

ModelOutputT

Type of model outputs.

ModelParamsT

Type of additional model parameters.

PARAMETER DESCRIPTION
train_dataloader

Data loader for training dataset.

TYPE: IDataLoader[ItemT, ModelInputT]

val_dataloader

Optional data loader for validation dataset.

TYPE: Optional[IDataLoader[ItemT, ModelInputT]] DEFAULT: None

engine

Training engine (defaults to DefaultTorchTrainingEngine).

TYPE: Optional[TorchTrainingEngine[ModelInputT, ModelOutputT, ModelParamsT]] DEFAULT: None

callbacks

Sequence of training callbacks.

TYPE: Sequence[TorchTrainingCallback] DEFAULT: ()

distributor

Device distributor (defaults to SingleDeviceDistributor).

TYPE: Optional[BaseDistributor] DEFAULT: None

max_epochs

Maximum number of training epochs.

TYPE: Optional[int] DEFAULT: None

eval_strategy

When to evaluate - "epoch" or "step".

TYPE: Literal['epoch', 'step'] DEFAULT: 'epoch'

eval_interval

Evaluation interval (epochs or steps).

TYPE: int DEFAULT: 1

logging_strategy

When to log - "epoch" or "step".

TYPE: Literal['epoch', 'step'] DEFAULT: 'epoch'

logging_interval

Logging interval (epochs or steps).

TYPE: int DEFAULT: 1

logging_first_step

Whether to log after the first training step.

TYPE: bool DEFAULT: True

train_prefix

Prefix for training metrics logging. Default is "train/".

TYPE: str DEFAULT: 'train/'

val_prefix

Prefix for validation metrics logging. Default is "val/".

TYPE: str DEFAULT: 'val/'

Examples:

>>> engine = DefaultTorchTrainingEngine(
...     optimizer=torch.optim.Adam,
...     lr_scheduler=torch.optim.lr_scheduler.StepLR
... )
>>> trainer = TorchTrainer(
...     train_dataloader=train_loader,
...     val_dataloader=val_loader,
...     engine=engine,
...     max_epochs=10,
...     eval_strategy="epoch",
...     logging_strategy="step",
...     logging_interval=100
... )
Source code in src/formed/integrations/torch/training/trainer.py
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
def __init__(
    self,
    *,
    train_dataloader: IDataLoader[ItemT, ModelInputT],
    val_dataloader: Optional[IDataLoader[ItemT, ModelInputT]] = None,
    engine: Optional[TorchTrainingEngine[ModelInputT, ModelOutputT, ModelParamsT]] = None,
    callbacks: Sequence[TorchTrainingCallback] = (),
    distributor: Optional[BaseDistributor] = None,
    max_epochs: Optional[int] = None,
    eval_strategy: Literal["epoch", "step"] = "epoch",
    eval_interval: int = 1,
    logging_strategy: Literal["epoch", "step"] = "epoch",
    logging_interval: int = 1,
    logging_first_step: bool = True,
    train_prefix: str = "train/",
    val_prefix: str = "val/",
) -> None:
    """Store trainer configuration, filling in defaults for omitted components."""
    # Evaluation / logging schedule.
    self._eval_strategy = eval_strategy
    self._eval_interval = eval_interval
    self._logging_strategy = logging_strategy
    self._logging_interval = logging_interval
    self._logging_first_step = logging_first_step
    # Metric-name prefixes used when reporting train/val metrics.
    self._train_prefix = train_prefix
    self._val_prefix = val_prefix
    # Data loaders and callbacks (validation loader is optional).
    self._train_dataloader = train_dataloader
    self._val_dataloader = val_dataloader
    self._callbacks = callbacks
    # Core components: fall back to library defaults when not supplied.
    self._engine = engine or DefaultTorchTrainingEngine[ModelInputT, ModelOutputT, ModelParamsT]()
    self._distributor = distributor or get_default_distributor()
    self._max_epochs = max_epochs or get_default_max_epochs()

distributor property

distributor

train

train(model, train_dataset, val_dataset=None, state=None)

Train a model on the provided datasets.

PARAMETER DESCRIPTION
model

Model to train.

TYPE: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT]

train_dataset

Sequence of training items.

TYPE: Sequence[ItemT]

val_dataset

Optional sequence of validation items.

TYPE: Optional[Sequence[ItemT]] DEFAULT: None

state

Optional pre-initialized training state (for resuming).

TYPE: Optional[TrainState] DEFAULT: None

RETURNS DESCRIPTION
TrainState

Final training state with trained parameters.

RAISES DESCRIPTION
ValueError

If val_dataset is provided but val_dataloader is not.

Examples:

>>> state = trainer.train(
...     model, train_items, val_items
... )
>>> # Load trained parameters
>>> model.load_state_dict(state.model_state)
Source code in src/formed/integrations/torch/training/trainer.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def train(
    self,
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    train_dataset: Sequence[ItemT],
    val_dataset: Optional[Sequence[ItemT]] = None,
    state: Optional[TrainState] = None,
) -> TrainState:
    """Run the training loop for `model` over the given datasets.

    Args:
        model: Model to train.
        train_dataset: Items used for training.
        val_dataset: Optional items used for validation.
        state: Optional pre-initialized training state, e.g. to resume a run.

    Returns:
        The final :class:`TrainState`, carrying the trained parameters.

    Raises:
        ValueError: When `val_dataset` is given but the trainer was built
            without a validation dataloader.

    Examples:
        >>> state = trainer.train(
        ...     model, train_items, val_items
        ... )
        >>> # Load trained parameters
        >>> model.load_state_dict(state.model_state)

    """
    # Guard clause: a validation set is unusable without a loader to batch it.
    if val_dataset is not None and self._val_dataloader is None:
        raise ValueError("Validation dataloader is not provided.")

    logger = use_step_logger(__name__)

    # All tensor creation inside the loop should land on the distributor's
    # device, so wrap the whole run in the device context.
    target_device = self._distributor.device
    with use_device(target_device):
        return self._train_impl(model, train_dataset, val_dataset, state, logger)

get_default_max_epochs

get_default_max_epochs()

Get a default maximum number of training epochs.

Source code in src/formed/integrations/torch/training/trainer.py
81
82
83
def get_default_max_epochs() -> int:
    """Return the fallback epoch budget (10) used when none is configured."""
    return 10

get_default_distributor

get_default_distributor()

Get a default single-device distributor.

Source code in src/formed/integrations/torch/training/trainer.py
86
87
88
def get_default_distributor() -> BaseDistributor:
    """Return the fallback distributor: a fresh single-device instance."""
    return SingleDeviceDistributor()

formed.integrations.torch.workflow

Workflow integration for PyTorch model training.

This module provides workflow steps for training PyTorch models, allowing them to be integrated into the formed workflow system with automatic caching and dependency tracking.

Available Steps
  • torch::train: Train a PyTorch model using the provided trainer.
  • torch::evaluate: Evaluate a PyTorch model on a dataset.
  • torch::predict: Generate predictions on a dataset using a PyTorch model.
  • torch::predict_without_caching: Generate predictions without caching (same as torch::predict but uncached).

Examples:

>>> from formed.integrations.torch import train_torch_model
>>>
>>> # In workflow configuration (jsonnet):
>>> # {
>>> #   steps: {
>>> #     train: {
>>> #       type: "torch::train",
>>> #       model: { type: "my_model", ... },
>>> #       trainer: { type: "torch_trainer", ... },
>>> #       train_dataset: { type: "ref", ref: "preprocess" },
>>> #       random_seed: 42
>>> #     }
>>> #   }
>>> # }

TorchModelFormat

Bases: Format[_ModelT]

identifier property

identifier

Get the unique identifier for this format.

RETURNS DESCRIPTION
str

Format identifier string.

write

write(artifact, directory)
Source code in src/formed/integrations/torch/workflow.py
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def write(self, artifact: _ModelT, directory: Path) -> None:
    """Serialize a model into `directory`.

    When the model carries a construction config (`__model_config__`), the
    config is written as JSON — augmented with its colt type key so `read`
    can rebuild the concrete class — alongside the state dict. Otherwise
    the whole model object is pickled as a fallback.

    Args:
        artifact: Model instance to persist.
        directory: Target directory for the serialized files.

    """
    if artifact.__model_config__ is not None:
        # Record the concrete class so `read` can reconstruct the model
        # via COLT_BUILDER.
        config = dict(artifact.__model_config__)
        config[COLT_TYPEKEY] = f"{artifact.__class__.__module__}:{artifact.__class__.__name__}"
        # BUG FIX: serialize the augmented `config` (with the type key),
        # not the raw `__model_config__` — the original dropped the type
        # key, which `read` needs to rebuild the model.
        self._get_config_path(directory).write_text(
            json.dumps(
                config,
                indent=2,
                cls=WorkflowJSONEncoder,
            )
        )
        torch.save(
            artifact.state_dict(),
            self._get_state_path(directory),
        )
    else:
        # No config available: pickle the whole model object instead.
        with self._get_pickle_path(directory).open("wb") as f:
            cloudpickle.dump(artifact, f)

read

read(directory)
Source code in src/formed/integrations/torch/workflow.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def read(self, directory: Path) -> _ModelT:
    """Load a model previously written by `write`.

    The pickle file takes precedence: it captures models saved without a
    construction config. Otherwise the model is rebuilt from its JSON
    config and its state dict is restored (weights loaded onto CPU).

    Args:
        directory: Directory containing the serialized files.

    Returns:
        The deserialized model instance.

    """
    pickle_path = self._get_pickle_path(directory)
    if pickle_path.exists():
        with pickle_path.open("rb") as f:
            return cast(_ModelT, cloudpickle.load(f))

    # Rebuild from config + state dict.
    config = json.loads(
        self._get_config_path(directory).read_text(),
        cls=WorkflowJSONDecoder,
    )
    state_dict = torch.load(self._get_state_path(directory), map_location="cpu")
    model = COLT_BUILDER(config, BaseTorchModel)
    model.load_state_dict(state_dict)
    return cast(_ModelT, model)

is_default_of classmethod

is_default_of(obj)

Check if this format is the default for the given object type.

PARAMETER DESCRIPTION
obj

Object to check.

TYPE: Any

RETURNS DESCRIPTION
bool

True if this format should be used by default for this type.

Source code in src/formed/workflow/format.py
101
102
103
104
105
106
107
108
109
110
111
112
@classmethod
def is_default_of(cls, obj: Any) -> bool:
    """Report whether this format is the automatic choice for `obj`.

    The base implementation never claims any object; subclasses override
    this to opt in for the types they handle.

    Args:
        obj: Candidate object to test.

    Returns:
        Always False for this format.

    """
    return False

train_torch_model

train_torch_model(
    model,
    trainer,
    train_dataset,
    val_dataset=None,
    random_seed=0,
)

Train a PyTorch model using the provided trainer.

This workflow step trains a PyTorch model on the provided datasets, returning the trained model. The training process is cached based on the model architecture, trainer configuration, and dataset fingerprints.

PARAMETER DESCRIPTION
model

PyTorch model to train.

TYPE: Lazy[BaseTorchModel]

trainer

Trainer configuration with dataloaders and callbacks.

TYPE: TorchTrainer

train_dataset

Training dataset items.

TYPE: Sequence[ItemT]

val_dataset

Optional validation dataset items.

TYPE: Sequence[ItemT] | None DEFAULT: None

random_seed

Random seed for reproducibility.

TYPE: int DEFAULT: 0

RETURNS DESCRIPTION
BaseTorchModel

Trained PyTorch model with updated parameters.

Examples:

>>> # Use in Python code
>>> trained_model = train_torch_model(
...     model=my_model,
...     trainer=trainer,
...     train_dataset=train_data,
...     val_dataset=val_data,
...     random_seed=42
... )
Source code in src/formed/integrations/torch/workflow.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
@step("torch::train", format=TorchModelFormat())
def train_torch_model(
    model: Lazy[BaseTorchModel],
    trainer: TorchTrainer,
    train_dataset: Sequence[ItemT],
    val_dataset: Sequence[ItemT] | None = None,
    random_seed: int = 0,
) -> BaseTorchModel:
    """Train a PyTorch model with the given trainer and return it.

    This workflow step builds the model from its lazy configuration,
    trains it on the provided datasets, and returns the trained model.
    Caching is driven by the model architecture, trainer configuration,
    and dataset fingerprints.

    Args:
        model: Lazily-configured PyTorch model to train.
        trainer: Trainer configuration with dataloaders and callbacks.
        train_dataset: Training dataset items.
        val_dataset: Optional validation dataset items.
        random_seed: Random seed for reproducibility.

    Returns:
        Trained PyTorch model with updated parameters.

    Examples:
        >>> # Use in Python code
        >>> trained_model = train_torch_model(
        ...     model=my_model,
        ...     trainer=trainer,
        ...     train_dataset=train_data,
        ...     val_dataset=val_data,
        ...     random_seed=42
        ... )

    """
    # Seed all RNGs so the run is reproducible (and therefore cacheable).
    set_random_seed(random_seed)

    # Materialize the lazily-configured model and keep its config around so
    # the trained artifact can be serialized (see TorchModelFormat).
    instance = model.construct()
    instance.__model_config__ = model.config

    # Run training and hand back the trained model from the final state.
    final_state = trainer.train(instance, train_dataset, val_dataset)
    return cast(BaseTorchModel, final_state.model)

evaluate_torch_model

evaluate_torch_model(
    model,
    evaluator,
    dataset,
    dataloader,
    params=None,
    random_seed=None,
    device=None,
)

Evaluate a PyTorch model on a dataset using the provided evaluator.

This workflow step evaluates a PyTorch model on the provided dataset, computing metrics using the evaluator. Evaluation is performed in evaluation mode (no gradient computation).

PARAMETER DESCRIPTION
model

PyTorch model to evaluate.

TYPE: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT]

evaluator

Evaluator to compute metrics.

TYPE: IEvaluator[ModelInputT, ModelOutputT]

dataset

Dataset items for evaluation.

TYPE: Iterable[ItemT]

dataloader

DataLoader to convert items to model inputs.

TYPE: IStreamingDataLoader[ItemT, ModelInputT]

params

Optional model parameters to use for evaluation.

TYPE: ModelParamsT | None DEFAULT: None

random_seed

Optional random seed for reproducibility.

TYPE: int | None DEFAULT: None

device

Optional device (e.g., "cpu", "cuda") to run evaluation on.

TYPE: str | device | None DEFAULT: None

RETURNS DESCRIPTION
dict[str, float]

Dictionary of computed evaluation metrics.

Examples:

>>> # Use in Python code
>>> metrics = evaluate_torch_model(
...     model=trained_model,
...     evaluator=my_evaluator,
...     dataset=test_data,
...     dataloader=test_loader
... )
Source code in src/formed/integrations/torch/workflow.py
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
@step("torch::evaluate", format="json")
def evaluate_torch_model(
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    evaluator: IEvaluator[ModelInputT, ModelOutputT],
    dataset: Iterable[ItemT],
    dataloader: IStreamingDataLoader[ItemT, ModelInputT],
    params: ModelParamsT | None = None,
    random_seed: int | None = None,
    device: Annotated[str | torch.device | None, WorkflowStepArgFlag.IGNORE] = None,
) -> Annotated[dict[str, float], WorkflowStepResultFlag.METRICS]:
    """Compute evaluation metrics for a PyTorch model over a dataset.

    Runs the model in inference mode (no gradients) over every batch
    produced by the dataloader, feeding inputs and outputs to the
    evaluator, then aggregates the metrics.

    Args:
        model: PyTorch model to evaluate.
        evaluator: Evaluator that accumulates and computes metrics.
        dataset: Dataset items for evaluation.
        dataloader: DataLoader converting items to model inputs.
        params: Optional model parameters to use for evaluation.
        random_seed: Optional random seed for reproducibility.
        device: Optional device (e.g., `"cpu"`, `"cuda"`) to evaluate on.

    Returns:
        Dictionary of computed evaluation metrics.

    Examples:
        >>> # Use in Python code
        >>> metrics = evaluate_torch_model(
        ...     model=trained_model,
        ...     evaluator=my_evaluator,
        ...     dataset=test_data,
        ...     dataloader=test_loader
        ... )

    """
    logger = use_step_logger(__name__)

    # Optional determinism: seed CPU RNG and, when present, all CUDA devices.
    if random_seed is not None:
        torch.manual_seed(random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(random_seed)

    # Inference mode disables autograd; use_device resolves the target device.
    with torch.inference_mode(), use_device(device) as resolved_device:
        model.to(resolved_device)
        model.eval()

        # Start from a clean accumulator state.
        evaluator.reset()

        with (
            closing(dataloader(dataset)) as loader,
            progress(loader, desc="Evaluating model") as batches,
        ):
            for batch in batches:
                batch = move_to_device(batch, resolved_device)
                evaluator.update(batch, model(batch, params))

    # Aggregate only after the full pass over the dataset.
    metrics = evaluator.compute()
    logger.info("Evaluation metrics: %s", ", ".join(f"{k}={v:.4f}" for k, v in metrics.items()))

    return metrics

predict

predict(
    dataset,
    dataloader,
    model,
    postprocessor,
    params=None,
    device=None,
    random_seed=None,
)

Generate predictions on a dataset using a PyTorch model.

This step applies a model to a dataset and postprocesses the outputs to generate final predictions.

PARAMETER DESCRIPTION
dataset

Dataset items for prediction.

TYPE: Iterable[ItemT]

dataloader

DataLoader to convert items to model inputs.

TYPE: IStreamingDataLoader[ItemT, ModelInputT]

model

PyTorch model to use for prediction.

TYPE: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT]

postprocessor

Function to convert model outputs to final results.

TYPE: Callable[[ModelInputT, ModelOutputT], Iterable[_ResultT]]

params

Optional model parameters to use for prediction.

TYPE: ModelParamsT | None DEFAULT: None

device

Optional device (e.g., "cpu", "cuda") to run prediction on.

TYPE: str | device | None DEFAULT: None

random_seed

Optional random seed for reproducibility.

TYPE: int | None DEFAULT: None

RETURNS DESCRIPTION
Iterator[_ResultT]

Iterator of prediction results.

Source code in src/formed/integrations/torch/workflow.py
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
@step("torch::predict")
@step("torch::predict_without_caching", cacheable=False)
def predict(
    dataset: Iterable[ItemT],
    dataloader: IStreamingDataLoader[ItemT, ModelInputT],
    model: BaseTorchModel[ModelInputT, ModelOutputT, ModelParamsT],
    postprocessor: Callable[[ModelInputT, ModelOutputT], Iterable[_ResultT]],
    params: ModelParamsT | None = None,
    device: Annotated[str | torch.device | None, WorkflowStepArgFlag.IGNORE] = None,
    random_seed: int | None = None,
) -> Iterator[_ResultT]:
    """Yield predictions for a dataset using a PyTorch model.

    Runs the model in inference mode over every batch from the dataloader
    and streams the postprocessed outputs as a lazy iterator.

    Args:
        dataset: Dataset items for prediction.
        dataloader: DataLoader converting items to model inputs.
        model: PyTorch model to use for prediction.
        postprocessor: Function mapping (inputs, outputs) to final results.
        params: Optional model parameters to use for prediction.
        device: Optional device (e.g., `"cpu"`, `"cuda"`) to predict on.
        random_seed: Optional random seed for reproducibility.

    Returns:
        Iterator of prediction results.
    """
    # Optional determinism: seed CPU RNG and, when present, all CUDA devices.
    if random_seed is not None:
        torch.manual_seed(random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(random_seed)

    # Inference mode disables autograd; use_device resolves the target device.
    with torch.inference_mode(), use_device(device) as resolved_device:
        model.to(resolved_device)
        model.eval()

        with (
            closing(dataloader(dataset)) as loader,
            progress(loader, desc="Predicting") as batches,
        ):
            for batch in batches:
                batch = move_to_device(batch, resolved_device)
                yield from postprocessor(batch, model(batch, params))