
ML

formed.integrations.ml.dataloader

BaseBatchSampler

Bases: Registrable, ABC, Generic[_InputT]

Abstract base class for batch samplers.

Batch samplers generate sequences of indices for batching data.

BasicBatchSampler

BasicBatchSampler(
    batch_size=1, shuffle=False, drop_last=False, seed=0
)

Bases: BaseBatchSampler[_InputT], Generic[_InputT]

Basic batch sampler that supports shuffling and dropping incomplete batches.

PARAMETER DESCRIPTION
batch_size

Number of samples per batch. Defaults to 1.

TYPE: int DEFAULT: 1

shuffle

Whether to shuffle the data before batching. Defaults to False.

TYPE: bool DEFAULT: False

drop_last

Whether to drop the last incomplete batch. Defaults to False.

TYPE: bool DEFAULT: False

seed

Random seed for shuffling. Defaults to 0.

TYPE: int DEFAULT: 0

Examples:

>>> sampler = BasicBatchSampler(batch_size=32, shuffle=True)
>>> dataset = list(range(100))
>>> batches = list(sampler(dataset))
>>> len(batches)
4
Source code in src/formed/integrations/ml/dataloader.py
def __init__(
    self,
    batch_size: int = 1,
    shuffle: bool = False,
    drop_last: bool = False,
    seed: int = 0,
) -> None:
    self._batch_size = batch_size
    self._shuffle = shuffle
    self._drop_last = drop_last
    self._rng = random.Random(seed)
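The constructor above only stores configuration; the index sequences are produced when the sampler is applied to a dataset. A minimal standalone sketch of that logic (the helper `make_batches` is ours, not part of `formed`): shuffle the indices if requested, slice into fixed-size batches, and optionally skip the incomplete final batch.

```python
import random
from typing import Iterator


def make_batches(
    n_items: int,
    batch_size: int = 1,
    shuffle: bool = False,
    drop_last: bool = False,
    seed: int = 0,
) -> Iterator[list[int]]:
    """Yield index batches the way a basic batch sampler would."""
    indices = list(range(n_items))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, n_items, batch_size):
        batch = indices[start : start + batch_size]
        if drop_last and len(batch) < batch_size:
            continue  # skip the incomplete final batch
        yield batch


print(len(list(make_batches(100, batch_size=32))))  # 4 (32 + 32 + 32 + 4)
print(len(list(make_batches(100, batch_size=32, drop_last=True))))  # 3
```

This also explains the `len(batches) == 4` in the example above: 100 items at batch size 32 yield three full batches plus one of 4.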

SizeOrderedBucketBatchSampler

SizeOrderedBucketBatchSampler(
    attribute,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    seed=0,
)

Bases: BaseBatchSampler[_InputT], Generic[_InputT]

Batch sampler that orders data by size before batching.

This sampler sorts the dataset based on a specified size attribute, then creates batches of a given size. It can optionally shuffle the batches after creation.

PARAMETER DESCRIPTION
attribute

Attribute name or callable to determine the size of each item.

TYPE: str | Callable[[_InputT], Sized]

batch_size

Number of samples per batch. Defaults to 1.

TYPE: int DEFAULT: 1

shuffle

Whether to shuffle the batches after creation. Defaults to False.

TYPE: bool DEFAULT: False

drop_last

Whether to drop the last incomplete batch. Defaults to False.

TYPE: bool DEFAULT: False

seed

Random seed for shuffling. Defaults to 0.

TYPE: int DEFAULT: 0

Examples:

>>> sampler = SizeOrderedBucketBatchSampler(attribute="length", batch_size=16, shuffle=True)
>>> dataset = [{"length": i} for i in range(100)]
>>> batches = list(sampler(dataset))
>>> len(batches)
7
Source code in src/formed/integrations/ml/dataloader.py
def __init__(
    self,
    attribute: str | Callable[[_InputT], Sized],
    batch_size: int = 1,
    shuffle: bool = False,
    drop_last: bool = False,
    seed: int = 0,
) -> None:
    self._attribute = attribute if callable(attribute) else partial(xgetattr, name=attribute)
    self._batch_size = batch_size
    self._shuffle = shuffle
    self._drop_last = drop_last
    self._rng = random.Random(seed)
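The constructor just normalizes `attribute` into a callable; the bucketing itself can be pictured as: sort indices by item size, slice into fixed-size batches, then optionally shuffle whole batches (not items, so each batch still groups similarly sized items). A standalone sketch under those assumptions; `size_ordered_batches` is our own helper, not the library's API:

```python
import random
from typing import Callable, Sequence


def size_ordered_batches(
    data: Sequence,
    size_of: Callable,
    batch_size: int = 1,
    shuffle: bool = False,
    seed: int = 0,
) -> list[list[int]]:
    # Sort indices by item size so each batch groups similarly sized items,
    # which keeps padding waste low when collating variable-length data.
    order = sorted(range(len(data)), key=lambda i: size_of(data[i]))
    batches = [order[i : i + batch_size] for i in range(0, len(order), batch_size)]
    if shuffle:
        random.Random(seed).shuffle(batches)  # shuffle batch order, not items
    return batches


texts = ["aaaa", "a", "aa", "aaa"]
print(size_ordered_batches(texts, len, batch_size=2))  # [[1, 2], [3, 0]]
```

Grouping by size before batching is the usual trick for variable-length sequence data: batches of near-equal lengths need far less padding.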

BatchIterator

BatchIterator(data, sampler, collator)

Bases: Generic[_InputT, _BatchT]

Iterator that generates batches from a dataset using a sampler and collator.

PARAMETER DESCRIPTION
data

The dataset to iterate over.

TYPE: Sequence[_InputT]

sampler

Batch sampler that generates batch indices.

TYPE: BaseBatchSampler

collator

Function that collates a sequence of items into a batch.

TYPE: Callable[[Sequence[_InputT]], _BatchT]

Examples:

>>> def collator(batch):
...     return [x * 2 for x in batch]
>>> sampler = BasicBatchSampler(batch_size=2)
>>> iterator = BatchIterator([1, 2, 3, 4], sampler, collator)
>>> list(iterator)
[[2, 4], [6, 8]]
Source code in src/formed/integrations/ml/dataloader.py
def __init__(
    self,
    data: Sequence[_InputT],
    sampler: BaseBatchSampler,
    collator: Callable[[Sequence[_InputT]], _BatchT],
) -> None:
    self._data = data
    self._collator = collator
    self._sampler = sampler
    self._iterator: Iterator[Sequence[int]] | None = None
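The constructor only stores its collaborators; iteration presumably walks the sampler's index batches, gathers the corresponding items, and collates each group. A minimal sketch of that loop (our own function, assuming the sampler yields index sequences as shown above):

```python
from typing import Callable, Iterator, Sequence


def iterate_batches(
    data: Sequence[int],
    index_batches: Iterator[Sequence[int]],
    collate: Callable[[Sequence[int]], list[int]],
) -> Iterator[list[int]]:
    # For each batch of indices, gather the items and hand them to the collator.
    for indices in index_batches:
        yield collate([data[i] for i in indices])


data = [1, 2, 3, 4]
batches = iterate_batches(data, iter([[0, 1], [2, 3]]), lambda b: [x * 2 for x in b])
print(list(batches))  # [[2, 4], [6, 8]]
```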

DataLoader

DataLoader(sampler, collator, buffer_size=0)

Bases: Generic[_InputT, _BatchT]

Data loader that creates batched iterators from datasets.

The DataLoader combines a batch sampler and a collator function to create batched data iterators. It optionally supports buffering using a separate process to prefetch batches in the background, which can improve throughput when batch collation is expensive.

PARAMETER DESCRIPTION
sampler

Batch sampler that generates batch indices.

TYPE: BaseBatchSampler

collator

Function that collates a sequence of items into a batch.

TYPE: Callable[[Sequence[_InputT]], _BatchT]

buffer_size

Size of the prefetch buffer. If 0, buffering is disabled. If > 0, batches are prepared in a background process. Defaults to 0.

TYPE: int DEFAULT: 0

Examples:

Basic usage without buffering:

>>> def collator(batch):
...     return [x * 2 for x in batch]
>>> sampler = BasicBatchSampler(batch_size=4, shuffle=True)
>>> loader = DataLoader(sampler=sampler, collator=collator)
>>> dataset = list(range(10))
>>> batches = list(loader(dataset))

With buffering for better performance:

>>> loader = DataLoader(
...     sampler=sampler,
...     collator=collator,
...     buffer_size=10,  # Prefetch up to 10 batches
... )
>>> from formed.common.ctxutils import closing
>>> with closing(loader(dataset)) as batches:
...     for batch in batches:
...         # Process batch
...         pass
Note

When using buffering (buffer_size > 0), the collator function and any objects it references must be picklable, as they will be passed to a background process. Also, it's recommended to use the loader with a context manager (closing) to ensure proper cleanup of the background process.
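Because a buffered loader ships the collator across a process boundary, a quick preflight check can catch unpicklable objects before they fail inside the background process. A sketch (assuming the standard `pickle` module, which is the usual transport for multiprocessing; `is_picklable` is our own helper):

```python
import pickle


def doubling_collator(batch):  # defined at module level, so picklable
    return [x * 2 for x in batch]


def is_picklable(obj) -> bool:
    """Cheap preflight check before handing an object to a worker process."""
    try:
        pickle.dumps(obj)
        return True
    except Exception:
        return False


print(is_picklable([1, 2, 3]))    # True: plain data pickles fine
print(is_picklable(lambda b: b))  # False: lambdas cannot be pickled
```

This is why module-level collator functions are preferred over lambdas or closures when `buffer_size > 0`.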

Source code in src/formed/integrations/ml/dataloader.py
def __init__(
    self,
    sampler: BaseBatchSampler,
    collator: Callable[[Sequence[_InputT]], _BatchT],
    buffer_size: int = 0,
) -> None:
    self._collator = collator
    self._sampler = sampler
    self._buffer_size = buffer_size

formed.integrations.ml.metrics

Metrics for evaluating machine learning models.

This module provides a comprehensive set of metrics for evaluating classification, regression, and ranking models. All metrics follow a common interface with reset, update, and compute methods.

Key Components

Base Classes:

  • BaseMetric: Abstract base for all metrics
  • BinaryClassificationMetric: Base for binary classification metrics
  • MulticlassClassificationMetric: Base for multiclass metrics
  • MultilabelClassificationMetric: Base for multilabel metrics
  • RegressionMetric: Base for regression metrics
  • RankingMetric: Base for ranking metrics

Classification Metrics:

  • BinaryAccuracy: Binary classification accuracy
  • BinaryFBeta: Binary F-beta score (precision, recall, F1)
  • BinaryROCAUC: Binary ROC AUC (requires scores)
  • BinaryPRAUC: Binary PR AUC (Precision-Recall curve, requires scores)
  • MulticlassAccuracy: Multiclass accuracy (micro/macro)
  • MulticlassFBeta: Multiclass F-beta (micro/macro)
  • MultilabelAccuracy: Multilabel accuracy
  • MultilabelFBeta: Multilabel F-beta

Regression Metrics:

  • MeanAbsoluteError: MAE metric
  • MeanSquaredError: MSE metric

Ranking Metrics:

  • MeanAveragePrecision: MAP metric
  • NDCG: Normalized Discounted Cumulative Gain

Utility Metrics:

  • Average: Simple averaging metric
  • EmptyMetric: No-op metric

Features

  • Stateful metrics with accumulation across batches
  • Support for micro and macro averaging
  • Flexible label types (int, str, bool)
  • Registrable for configuration-based instantiation

Examples:

>>> from formed.integrations.ml.metrics import MulticlassAccuracy, ClassificationInput
>>>
>>> # Create metric
>>> metric = MulticlassAccuracy(average="macro")
>>>
>>> # Update with batch
>>> inputs = ClassificationInput(
...     predictions=[0, 1, 2, 1],
...     targets=[0, 1, 1, 1]
... )
>>> metric.update(inputs)
>>>
>>> # Compute final metrics
>>> results = metric.compute()
>>> print(results)  # {"accuracy": 0.833...}
>>>
>>> # Reset for next evaluation
>>> metric.reset()

BaseMetric

Bases: Registrable, Generic[_T], ABC

Abstract base class for all metrics.

Metrics are stateful objects that accumulate predictions and targets across multiple batches, then compute aggregate statistics.

CLASS TYPE PARAMETER DESCRIPTION
_T

Type of input data for this metric.

Examples:

>>> @BaseMetric.register("my_metric")
... class MyMetric(BaseMetric[MyInputType]):
...     def reset(self):
...         self._state = 0
...     def update(self, inputs):
...         self._state += process(inputs)
...     def compute(self):
...         return {"metric": self._state}
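The registered example above is schematic: `process` and `MyInputType` are placeholders. The same reset/update/compute contract can be exercised with a concrete standalone class (no registry involved; `SumMetric` is our own toy metric, not part of the library):

```python
from typing import Sequence


class SumMetric:
    """Toy metric following the reset/update/compute contract."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self._state = 0.0  # clear accumulated statistics

    def update(self, inputs: Sequence[float]) -> None:
        self._state += sum(inputs)  # accumulate across batches

    def compute(self) -> dict[str, float]:
        return {"metric": self._state}


m = SumMetric()
m.update([1.0, 2.0])
m.update([3.0])
print(m.compute())  # {'metric': 6.0}
```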

reset abstractmethod

reset()

Reset internal state for a new evaluation.

This should clear all accumulated statistics.

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def reset(self) -> None:
    """Reset internal state for a new evaluation.

    This should clear all accumulated statistics.

    """
    raise NotImplementedError()

update abstractmethod

update(inputs)

Update internal state with a batch of predictions.

PARAMETER DESCRIPTION
inputs

Batch of predictions and targets.

TYPE: _T

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def update(self, inputs: _T) -> None:
    """Update internal state with a batch of predictions.

    Args:
        inputs: Batch of predictions and targets.

    """
    raise NotImplementedError()

compute abstractmethod

compute()

Compute metrics from accumulated state.

RETURNS DESCRIPTION
dict[str, float]

Dictionary mapping metric names to values.

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def compute(self) -> dict[str, float]:
    """Compute metrics from accumulated state.

    Returns:
        Dictionary mapping metric names to values.

    """
    raise NotImplementedError()

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

EmptyMetric

Bases: BaseMetric[Any]

No-op metric that does nothing.

This metric can be used as a placeholder when no evaluation is needed.

reset

reset()
Source code in src/formed/integrations/ml/metrics.py
def reset(self) -> None:
    pass

update

update(inputs)
Source code in src/formed/integrations/ml/metrics.py
def update(self, inputs: Any) -> None:
    pass

compute

compute()
Source code in src/formed/integrations/ml/metrics.py
def compute(self) -> dict[str, float]:
    return {}

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

Average

Average(name='average')

Bases: BaseMetric[Sequence[float]]

Simple averaging metric for numeric values.

Computes the mean of all values seen across batches.

PARAMETER DESCRIPTION
name

Name for the metric in output dictionary.

TYPE: str DEFAULT: 'average'

Examples:

>>> metric = Average(name="loss")
>>> metric.update([1.0, 2.0, 3.0])
>>> metric.update([4.0, 5.0])
>>> metric.compute()  # {"loss": 3.0}
Source code in src/formed/integrations/ml/metrics.py
def __init__(self, name: str = "average") -> None:
    self._name = name
    self._total = 0.0
    self._count = 0

reset

reset()
Source code in src/formed/integrations/ml/metrics.py
def reset(self) -> None:
    self._total = 0.0
    self._count = 0

update

update(inputs)
Source code in src/formed/integrations/ml/metrics.py
def update(self, inputs: Sequence[float]) -> None:
    self._total += sum(inputs)
    self._count += len(inputs)

compute

compute()
Source code in src/formed/integrations/ml/metrics.py
def compute(self) -> dict[str, float]:
    return {self._name: self._total / self._count if self._count > 0 else 0.0}
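Assembled from the method bodies listed above, `Average` can be exercised end to end. This standalone copy mirrors the shown source so the doctest-style example can be run outside the library:

```python
class Average:
    """Running mean over all values seen across update() calls."""

    def __init__(self, name: str = "average") -> None:
        self._name = name
        self._total = 0.0
        self._count = 0

    def reset(self) -> None:
        self._total = 0.0
        self._count = 0

    def update(self, inputs) -> None:
        self._total += sum(inputs)
        self._count += len(inputs)

    def compute(self) -> dict[str, float]:
        # Guard against division by zero when no values have been seen.
        return {self._name: self._total / self._count if self._count > 0 else 0.0}


metric = Average(name="loss")
metric.update([1.0, 2.0, 3.0])
metric.update([4.0, 5.0])
print(metric.compute())  # {'loss': 3.0}
```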

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

ClassificationInput dataclass

ClassificationInput(predictions, targets)

Bases: Generic[_T]

Input data for classification metrics.

ATTRIBUTE DESCRIPTION
predictions

Sequence of predicted labels.

TYPE: Sequence[_T]

targets

Sequence of ground truth labels.

TYPE: Sequence[_T]

predictions instance-attribute

predictions

targets instance-attribute

targets

BinaryClassificationInput dataclass

BinaryClassificationInput(
    predictions, targets, scores=None
)

Bases: Generic[BinaryLabelT]

Input data for binary classification metrics that require probability scores.

ATTRIBUTE DESCRIPTION
predictions

Sequence of predicted labels.

TYPE: Sequence[BinaryLabelT]

scores

Sequence of prediction scores (probabilities for positive class).

TYPE: Sequence[float] | None

targets

Sequence of ground truth labels.

TYPE: Sequence[BinaryLabelT]

predictions instance-attribute

predictions

targets instance-attribute

targets

scores class-attribute instance-attribute

scores = None

BinaryClassificationMetric

Bases: BaseMetric[BinaryClassificationInput[BinaryLabelT]], Generic[BinaryLabelT]

Base class for binary classification metrics.

Binary classification metrics work with two classes (`0`/`1` or `True`/`False`).

CLASS TYPE PARAMETER DESCRIPTION
BinaryLabelT

Type of labels (int, bool, etc.).

Input class-attribute instance-attribute

reset abstractmethod

reset()

Reset internal state for a new evaluation.

This should clear all accumulated statistics.

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def reset(self) -> None:
    """Reset internal state for a new evaluation.

    This should clear all accumulated statistics.

    """
    raise NotImplementedError()

update abstractmethod

update(inputs)

Update internal state with a batch of predictions.

PARAMETER DESCRIPTION
inputs

Batch of predictions and targets.

TYPE: _T

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def update(self, inputs: _T) -> None:
    """Update internal state with a batch of predictions.

    Args:
        inputs: Batch of predictions and targets.

    """
    raise NotImplementedError()

compute abstractmethod

compute()

Compute metrics from accumulated state.

RETURNS DESCRIPTION
dict[str, float]

Dictionary mapping metric names to values.

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def compute(self) -> dict[str, float]:
    """Compute metrics from accumulated state.

    Returns:
        Dictionary mapping metric names to values.

    """
    raise NotImplementedError()

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

BinaryAccuracy

BinaryAccuracy()

Bases: BinaryClassificationMetric[BinaryLabelT], Generic[BinaryLabelT]

Binary classification accuracy metric.

Computes the fraction of correct predictions.

Examples:

>>> metric = BinaryAccuracy()
>>> inputs = BinaryClassificationInput(
...     predictions=[1, 0, 1, 1],
...     targets=[1, 0, 0, 1]
... )
>>> metric.update(inputs)
>>> metric.compute()  # {"accuracy": 0.75}
Source code in src/formed/integrations/ml/metrics.py
def __init__(self) -> None:
    self._correct = 0
    self._total = 0

Input class-attribute instance-attribute

reset

reset()
Source code in src/formed/integrations/ml/metrics.py
def reset(self) -> None:
    self._correct = 0
    self._total = 0

update

update(inputs)
Source code in src/formed/integrations/ml/metrics.py
def update(self, inputs: BinaryClassificationInput[BinaryLabelT]) -> None:
    predictions = inputs.predictions
    targets = inputs.targets
    assert len(predictions) == len(targets), "Predictions and targets must have the same length"

    for pred, target in zip(predictions, targets):
        if pred == target:
            self._correct += 1
        self._total += 1

compute

compute()
Source code in src/formed/integrations/ml/metrics.py
def compute(self) -> dict[str, float]:
    accuracy = self._correct / self._total if self._total > 0 else 0.0
    return {"accuracy": accuracy}

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

BinaryFBeta

BinaryFBeta(beta=1.0)

Bases: BinaryClassificationMetric[BinaryLabelT], Generic[BinaryLabelT]

Binary F-beta score with precision and recall.

Computes F-beta score, precision, and recall for binary classification. F-beta is the weighted harmonic mean of precision and recall, where beta controls the weight of recall relative to precision.

PARAMETER DESCRIPTION
beta

Weight of recall relative to precision. Common values:

  • 1.0: F1 score (balanced)
  • 0.5: F0.5 (emphasizes precision)
  • 2.0: F2 (emphasizes recall)

TYPE: float DEFAULT: 1.0

RETURNS DESCRIPTION

Dictionary with "fbeta", "precision", and "recall" metrics.

Examples:

>>> # F1 score (beta=1.0)
>>> metric = BinaryFBeta(beta=1.0)
>>> inputs = BinaryClassificationInput(
...     predictions=[1, 1, 0, 1],
...     targets=[1, 0, 0, 1]
... )
>>> metric.update(inputs)
>>> result = metric.compute()
>>> # {"fbeta": 0.8, "precision": 0.67, "recall": 1.0}
Source code in src/formed/integrations/ml/metrics.py
def __init__(self, beta: float = 1.0) -> None:
    self._beta = beta
    self._true_positive = 0
    self._false_positive = 0
    self._false_negative = 0

Input class-attribute instance-attribute

reset

reset()
Source code in src/formed/integrations/ml/metrics.py
def reset(self) -> None:
    self._true_positive = 0
    self._false_positive = 0
    self._false_negative = 0

update

update(inputs)
Source code in src/formed/integrations/ml/metrics.py
def update(self, inputs: BinaryClassificationInput[BinaryLabelT]) -> None:
    predictions = inputs.predictions
    targets = inputs.targets
    assert len(predictions) == len(targets), "Predictions and targets must have the same length"

    for pred, target in zip(predictions, targets):
        if pred == target == 1:
            self._true_positive += 1
        elif pred == 1 and target == 0:
            self._false_positive += 1
        elif pred == 0 and target == 1:
            self._false_negative += 1

compute

compute()
Source code in src/formed/integrations/ml/metrics.py
def compute(self) -> dict[str, float]:
    beta_sq = self._beta**2
    precision_denominator = self._true_positive + self._false_positive
    recall_denominator = self._true_positive + self._false_negative

    precision = self._true_positive / precision_denominator if precision_denominator > 0 else 0.0
    recall = self._true_positive / recall_denominator if recall_denominator > 0 else 0.0

    if precision + recall == 0:
        fbeta = 0.0
    else:
        fbeta = (1 + beta_sq) * (precision * recall) / (beta_sq * precision + recall)

    return {"fbeta": fbeta, "precision": precision, "recall": recall}
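The arithmetic in compute can be checked directly from the confusion counts. With the example's data (`predictions=[1, 1, 0, 1]`, `targets=[1, 0, 0, 1]`) we get tp=2, fp=1, fn=0, so precision = 2/3, recall = 1.0, and balanced F1 = 2·(2/3)·1 / (2/3 + 1) = 0.8. A standalone helper (`fbeta_from_counts` is ours, mirroring the formula above):

```python
def fbeta_from_counts(tp: int, fp: int, fn: int, beta: float = 1.0) -> dict[str, float]:
    """F-beta, precision, and recall from confusion counts."""
    b2 = beta**2
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return {"fbeta": 0.0, "precision": precision, "recall": recall}
    fbeta = (1 + b2) * precision * recall / (b2 * precision + recall)
    return {"fbeta": fbeta, "precision": precision, "recall": recall}


print(fbeta_from_counts(2, 1, 0))  # fbeta ~ 0.8, precision ~ 0.667, recall = 1.0
```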

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

BinaryROCAUC

BinaryROCAUC()

Bases: BinaryClassificationMetric[BinaryLabelT], Generic[BinaryLabelT]

Binary ROC AUC (Area Under the Receiver Operating Characteristic Curve) metric.

Computes the area under the ROC curve, which measures the model's ability to distinguish between positive and negative classes across all thresholds. ROC AUC ranges from 0 to 1, where 0.5 represents random guessing and 1.0 represents perfect classification.

RETURNS DESCRIPTION

Dictionary with "roc_auc" metric.

Examples:

>>> metric = BinaryROCAUC()
>>> inputs = BinaryClassificationInput(
...     predictions=[1, 1, 0, 1],
...     scores=[0.9, 0.8, 0.3, 0.7],
...     targets=[1, 0, 0, 1]
... )
>>> metric.update(inputs)
>>> result = metric.compute()
>>> # {"roc_auc": 0.75}
Source code in src/formed/integrations/ml/metrics.py
def __init__(self) -> None:
    self._scores: list[float] = []
    self._targets: list[int] = []

Input class-attribute instance-attribute

reset

reset()
Source code in src/formed/integrations/ml/metrics.py
def reset(self) -> None:
    self._scores = []
    self._targets = []

update

update(inputs)
Source code in src/formed/integrations/ml/metrics.py
def update(self, inputs: BinaryClassificationInput[BinaryLabelT]) -> None:
    assert inputs.scores is not None, "Scores are required for ROC AUC computation"

    scores = inputs.scores
    targets = inputs.targets
    assert len(scores) == len(targets), "Scores and targets must have the same length"

    for score, target in zip(scores, targets):
        self._scores.append(score)
        self._targets.append(1 if target == 1 or target is True else 0)

compute

compute()
Source code in src/formed/integrations/ml/metrics.py
def compute(self) -> dict[str, float]:
    if not self._scores:
        return {"roc_auc": 0.0}

    # Count total positives and negatives
    n_pos = sum(self._targets)
    n_neg = len(self._targets) - n_pos

    if n_pos == 0 or n_neg == 0:
        return {"roc_auc": 0.0}

    # Sort by scores in descending order, with ties broken by target (negatives first)
    sorted_pairs = sorted(zip(self._scores, self._targets), key=lambda x: (-x[0], x[1]))

    # Calculate ROC curve points and AUC
    tp = 0
    fp = 0
    prev_tp = 0
    prev_fp = 0
    prev_score = float("inf")
    auc = 0.0

    for score, target in sorted_pairs:
        # When score changes, add area for previous threshold
        if score != prev_score:
            # Add trapezoid area: width * average height
            auc += (fp - prev_fp) * (tp + prev_tp) / 2.0
            prev_tp = tp
            prev_fp = fp
            prev_score = score

        if target == 1:
            tp += 1
        else:
            fp += 1

    # Add final trapezoid
    auc += (fp - prev_fp) * (tp + prev_tp) / 2.0

    # Normalize by total area
    auc /= n_pos * n_neg

    return {"roc_auc": auc}
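The trapezoidal sweep above is equivalent to the pairwise (Mann-Whitney) formulation of ROC AUC: the fraction of positive/negative pairs where the positive scores higher, counting ties as half. That makes a handy cross-check; `roc_auc_pairwise` below is our own O(n·m) reference, not the library's implementation:

```python
from typing import Sequence


def roc_auc_pairwise(scores: Sequence[float], targets: Sequence[int]) -> float:
    """ROC AUC as the win rate of positives over negatives (ties count 0.5)."""
    pos = [s for s, t in zip(scores, targets) if t == 1]
    neg = [s for s, t in zip(scores, targets) if t == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# The documented example: positives score {0.9, 0.7}, negatives {0.8, 0.3};
# 3 of 4 pairs are ranked correctly.
print(roc_auc_pairwise([0.9, 0.8, 0.3, 0.7], [1, 0, 0, 1]))  # 0.75
```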

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

BinaryPRAUC

BinaryPRAUC()

Bases: BinaryClassificationMetric[BinaryLabelT], Generic[BinaryLabelT]

Binary PR AUC (Area Under the Precision-Recall Curve) metric.

Computes the area under the Precision-Recall curve, which plots precision (y-axis) against recall (x-axis) at different classification thresholds. This metric is particularly useful for imbalanced datasets where ROC AUC might be overly optimistic.

Unlike ROC AUC which uses false positive rate, PR AUC focuses on the positive class performance, making it more informative when the positive class is rare.

RETURNS DESCRIPTION

Dictionary with "pr_auc" metric.

Examples:

>>> metric = BinaryPRAUC()
>>> inputs = BinaryClassificationInput(
...     predictions=[1, 1, 0, 1],
...     scores=[0.9, 0.8, 0.3, 0.7],
...     targets=[1, 0, 0, 1]
... )
>>> metric.update(inputs)
>>> result = metric.compute()
>>> # {"pr_auc": 0.833...}
Source code in src/formed/integrations/ml/metrics.py
def __init__(self) -> None:
    self._scores: list[float] = []
    self._targets: list[int] = []

Input class-attribute instance-attribute

reset

reset()
Source code in src/formed/integrations/ml/metrics.py
def reset(self) -> None:
    self._scores = []
    self._targets = []

update

update(inputs)
Source code in src/formed/integrations/ml/metrics.py
def update(self, inputs: BinaryClassificationInput[BinaryLabelT]) -> None:
    assert inputs.scores is not None, "Scores are required for PR AUC computation"

    scores = inputs.scores
    targets = inputs.targets
    assert len(scores) == len(targets), "Scores and targets must have the same length"

    for score, target in zip(scores, targets):
        self._scores.append(score)
        self._targets.append(1 if target == 1 or target is True else 0)

compute

compute()
Source code in src/formed/integrations/ml/metrics.py
def compute(self) -> dict[str, float]:
    if not self._scores:
        return {"pr_auc": 0.0}

    # Count total positives
    n_pos = sum(self._targets)
    if n_pos == 0:
        return {"pr_auc": 0.0}

    # Sort by scores in descending order (stable sort to preserve order for ties)
    # mergesort is stable in Python
    indexed_pairs = [(self._scores[i], self._targets[i], i) for i in range(len(self._scores))]
    indexed_pairs.sort(key=lambda x: -x[0], reverse=False)

    # Group by unique scores to handle ties correctly
    score_groups: list[list[tuple[float, int, int]]] = []
    current_group: list[tuple[float, int, int]] = []
    current_score = None

    for score, target, idx in indexed_pairs:
        if current_score is not None and score != current_score:
            score_groups.append(current_group)
            current_group = []
        current_group.append((score, target, idx))
        current_score = score

    if current_group:
        score_groups.append(current_group)

    # Calculate average precision: sum of precisions at each positive sample divided by n_pos
    # For tied scores, we use the average precision across the group
    tp = 0
    fp = 0
    ap_sum = 0.0

    for group in score_groups:
        # Count positives and negatives in this group
        group_positives = sum(1 for _, target, _ in group if target == 1)
        group_negatives = len(group) - group_positives

        if group_positives > 0:
            # sklearn uses: precision at the END of processing all samples in the group
            tp_after_group = tp + group_positives
            fp_after_group = fp + group_negatives

            # Add precision for each positive (using precision at end of group)
            precision_at_group = tp_after_group / (tp_after_group + fp_after_group)
            ap_sum += precision_at_group * group_positives

        tp += group_positives
        fp += group_negatives

    return {"pr_auc": ap_sum / n_pos}
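Ignoring the tie-group machinery (which only matters when scores repeat), the computation above reduces to average precision: walk the samples in descending score order and average the precision at each positive. For the documented example (scores [0.9, 0.8, 0.3, 0.7], targets [1, 0, 0, 1]) the positives are hit at precisions 1/1 and 2/3, giving (1 + 2/3)/2 ≈ 0.833. A simplified reference of our own, without tie handling:

```python
from typing import Sequence


def average_precision(scores: Sequence[float], targets: Sequence[int]) -> float:
    """AP without tie handling: mean precision at each positive, by score rank."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    tp = fp = 0
    ap = 0.0
    for i in order:
        if targets[i] == 1:
            tp += 1
            ap += tp / (tp + fp)  # precision at this positive
        else:
            fp += 1
    return ap / sum(targets)


print(average_precision([0.9, 0.8, 0.3, 0.7], [1, 0, 0, 1]))  # ~ 0.833
```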

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

MulticlassClassificationMetric

Bases: BaseMetric[ClassificationInput[LabelT]], Generic[LabelT]

Base class for multiclass classification metrics.

Multiclass metrics work with any number of classes and support both micro and macro averaging strategies.

CLASS TYPE PARAMETER DESCRIPTION
LabelT

Type of labels (int, str, etc.).

Input class-attribute instance-attribute

reset abstractmethod

reset()

Reset internal state for a new evaluation.

This should clear all accumulated statistics.

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def reset(self) -> None:
    """Reset internal state for a new evaluation.

    This should clear all accumulated statistics.

    """
    raise NotImplementedError()

update abstractmethod

update(inputs)

Update internal state with a batch of predictions.

PARAMETER DESCRIPTION
inputs

Batch of predictions and targets.

TYPE: _T

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def update(self, inputs: _T) -> None:
    """Update internal state with a batch of predictions.

    Args:
        inputs: Batch of predictions and targets.

    """
    raise NotImplementedError()

compute abstractmethod

compute()

Compute metrics from accumulated state.

RETURNS DESCRIPTION
dict[str, float]

Dictionary mapping metric names to values.

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def compute(self) -> dict[str, float]:
    """Compute metrics from accumulated state.

    Returns:
        Dictionary mapping metric names to values.

    """
    raise NotImplementedError()

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

MulticlassAccuracy

MulticlassAccuracy(average='micro')

Bases: MulticlassClassificationMetric[LabelT], Generic[LabelT]

Multiclass classification accuracy with averaging strategies.

Computes accuracy for multiclass classification with support for micro (overall accuracy) and macro (per-class average) strategies.

PARAMETER DESCRIPTION
average

Averaging strategy:

  • "micro": Overall accuracy across all samples
  • "macro": Average of per-class accuracies

TYPE: Literal['micro', 'macro'] DEFAULT: 'micro'

Examples:

>>> # Micro averaging (overall accuracy)
>>> metric = MulticlassAccuracy(average="micro")
>>> inputs = ClassificationInput(
...     predictions=[0, 1, 2, 1],
...     targets=[0, 1, 1, 1]
... )
>>> metric.update(inputs)
>>> metric.compute()  # {"accuracy": 0.75}
>>>
>>> # Macro averaging (per-class average)
>>> metric = MulticlassAccuracy(average="macro")
>>> metric.update(inputs)
>>> metric.compute()  # Average of class-wise accuracies
Source code in src/formed/integrations/ml/metrics.py
def __init__(self, average: Literal["micro", "macro"] = "micro") -> None:
    self._average = average
    self._correct: dict[LabelT, int] = defaultdict(int)
    self._total: dict[LabelT, int] = defaultdict(int)

Input class-attribute instance-attribute

reset

reset()
Source code in src/formed/integrations/ml/metrics.py
def reset(self) -> None:
    self._correct = defaultdict(int)
    self._total = defaultdict(int)

update

update(inputs)
Source code in src/formed/integrations/ml/metrics.py
def update(self, inputs: ClassificationInput[LabelT]) -> None:
    predictions = inputs.predictions
    targets = inputs.targets
    assert len(predictions) == len(targets), "Predictions and targets must have the same length"

    for pred, target in zip(predictions, targets):
        if pred == target:
            self._correct[target] += 1
        self._total[target] += 1

compute

compute()
Source code in src/formed/integrations/ml/metrics.py
def compute(self) -> dict[str, float]:
    if self._average == "micro":
        total_correct = sum(self._correct.values())
        total_count = sum(self._total.values())
        accuracy = total_correct / total_count if total_count > 0 else 0.0
        return {"accuracy": accuracy}
    elif self._average == "macro":
        accuracies = []
        for label in self._total.keys():
            correct = self._correct[label]
            total = self._total[label]
            accuracies.append(correct / total if total > 0 else 0.0)
        macro_accuracy = sum(accuracies) / len(accuracies) if accuracies else 0.0
        return {"accuracy": macro_accuracy}
    else:
        raise ValueError(f"Unknown average type: {self._average}")
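The micro/macro branches above can be illustrated with a minimal standalone sketch (independent of the formed API; the function name and inputs are illustrative):

```python
from collections import defaultdict

def multiclass_accuracy(preds, targets, average="micro"):
    # Tally per-class correct/total counts, mirroring the update/compute split above.
    correct, total = defaultdict(int), defaultdict(int)
    for pred, target in zip(preds, targets):
        if pred == target:
            correct[target] += 1
        total[target] += 1
    if average == "micro":
        # Overall fraction of correct predictions.
        return sum(correct.values()) / sum(total.values()) if total else 0.0
    # Macro: mean of per-class accuracies, weighting every observed class equally.
    accs = [correct[label] / total[label] for label in total]
    return sum(accs) / len(accs) if accs else 0.0

print(multiclass_accuracy([0, 1, 2, 1], [0, 1, 1, 1], "micro"))  # 0.75
print(multiclass_accuracy([0, 1, 2, 1], [0, 1, 1, 1], "macro"))  # (1/1 + 2/3) / 2
```

Note the macro result differs because class 0 (accuracy 1.0) and class 1 (accuracy 2/3) are averaged without regard to their sample counts.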

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

MulticlassFBeta

MulticlassFBeta(beta=1.0, average='micro')

Bases: MulticlassClassificationMetric[LabelT], Generic[LabelT]

Multiclass F-beta score with precision and recall.

Computes F-beta, precision, and recall for multiclass classification with support for micro and macro averaging.

PARAMETER DESCRIPTION
beta

Weight of recall relative to precision (default: 1.0 for F1).

TYPE: float DEFAULT: 1.0

average

Averaging strategy:
  • "micro": Compute globally across all classes
  • "macro": Compute per-class then average

TYPE: Literal['micro', 'macro'] DEFAULT: 'micro'

RETURNS DESCRIPTION

Dictionary with "fbeta", "precision", and "recall" metrics.

Examples:

>>> metric = MulticlassFBeta(beta=1.0, average="macro")
>>> inputs = ClassificationInput(
...     predictions=[0, 1, 2, 1],
...     targets=[0, 1, 1, 1]
... )
>>> metric.update(inputs)
>>> metric.compute()
>>> # {"fbeta": ..., "precision": ..., "recall": ...}
Source code in src/formed/integrations/ml/metrics.py
def __init__(self, beta: float = 1.0, average: Literal["micro", "macro"] = "micro") -> None:
    self._beta = beta
    self._average = average
    self._true_positive: dict[LabelT, int] = defaultdict(int)
    self._false_positive: dict[LabelT, int] = defaultdict(int)
    self._false_negative: dict[LabelT, int] = defaultdict(int)

Input class-attribute instance-attribute

reset

reset()
Source code in src/formed/integrations/ml/metrics.py
def reset(self) -> None:
    self._true_positive = defaultdict(int)
    self._false_positive = defaultdict(int)
    self._false_negative = defaultdict(int)

update

update(inputs)
Source code in src/formed/integrations/ml/metrics.py
def update(self, inputs: ClassificationInput[LabelT]) -> None:
    predictions = inputs.predictions
    targets = inputs.targets
    assert len(predictions) == len(targets), "Predictions and targets must have the same length"

    for pred, target in zip(predictions, targets):
        if pred == target:
            self._true_positive[target] += 1
        else:
            self._false_positive[pred] += 1
            self._false_negative[target] += 1

compute

compute()
Source code in src/formed/integrations/ml/metrics.py
def compute(self) -> dict[str, float]:
    beta_sq = self._beta**2

    if self._average == "micro":
        total_true_positive = sum(self._true_positive.values())
        total_false_positive = sum(self._false_positive.values())
        total_false_negative = sum(self._false_negative.values())

        precision_denominator = total_true_positive + total_false_positive
        recall_denominator = total_true_positive + total_false_negative

        precision = total_true_positive / precision_denominator if precision_denominator > 0 else 0.0
        recall = total_true_positive / recall_denominator if recall_denominator > 0 else 0.0

        if precision + recall == 0:
            fbeta = 0.0
        else:
            fbeta = (1 + beta_sq) * (precision * recall) / (beta_sq * precision + recall)

        return {"fbeta": fbeta, "precision": precision, "recall": recall}

    elif self._average == "macro":
        fbetas = []
        precisions = []
        recalls = []

        for label in (
            set(self._true_positive.keys()).union(self._false_positive.keys()).union(self._false_negative.keys())
        ):
            tp = self._true_positive[label]
            fp = self._false_positive[label]
            fn = self._false_negative[label]

            precision_denominator = tp + fp
            recall_denominator = tp + fn

            precision = tp / precision_denominator if precision_denominator > 0 else 0.0
            recall = tp / recall_denominator if recall_denominator > 0 else 0.0

            if precision + recall == 0:
                fbeta = 0.0
            else:
                fbeta = (1 + beta_sq) * (precision * recall) / (beta_sq * precision + recall)

            fbetas.append(fbeta)
            precisions.append(precision)
            recalls.append(recall)
        macro_fbeta = sum(fbetas) / len(fbetas) if fbetas else 0.0
        macro_precision = sum(precisions) / len(precisions) if precisions else 0.0
        macro_recall = sum(recalls) / len(recalls) if recalls else 0.0
        return {"fbeta": macro_fbeta, "precision": macro_precision, "recall": macro_recall}
    else:
        raise ValueError(f"Unknown average type: {self._average}")
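For the docstring example, the micro branch above reduces to a short standalone computation (a sketch, not the formed API; handles only the F1 case with at least one prediction):

```python
from collections import defaultdict

def multiclass_f1_micro(preds, targets):
    # A correct prediction is a true positive for its class; a wrong one is a
    # false positive for the predicted class and a false negative for the target class.
    tp, fp, fn = defaultdict(int), defaultdict(int), defaultdict(int)
    for pred, target in zip(preds, targets):
        if pred == target:
            tp[target] += 1
        else:
            fp[pred] += 1
            fn[target] += 1
    total_tp = sum(tp.values())
    precision = total_tp / (total_tp + sum(fp.values()))
    recall = total_tp / (total_tp + sum(fn.values()))
    return 2 * precision * recall / (precision + recall)

# 3 TP (classes 0 and 1), 1 FP (class 2), 1 FN (class 1):
# precision = recall = 3/4, so F1 = 0.75.
print(multiclass_f1_micro([0, 1, 2, 1], [0, 1, 1, 1]))  # 0.75
```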

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

MultilabelClassificationMetric

Bases: BaseMetric[ClassificationInput[Sequence[LabelT]]], Generic[LabelT]

Input class-attribute instance-attribute

reset abstractmethod

reset()

Reset internal state for a new evaluation.

This should clear all accumulated statistics.

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def reset(self) -> None:
    """Reset internal state for a new evaluation.

    This should clear all accumulated statistics.

    """
    raise NotImplementedError()

update abstractmethod

update(inputs)

Update internal state with a batch of predictions.

PARAMETER DESCRIPTION
inputs

Batch of predictions and targets.

TYPE: _T

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def update(self, inputs: _T) -> None:
    """Update internal state with a batch of predictions.

    Args:
        inputs: Batch of predictions and targets.

    """
    raise NotImplementedError()

compute abstractmethod

compute()

Compute metrics from accumulated state.

RETURNS DESCRIPTION
dict[str, float]

Dictionary mapping metric names to values.

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def compute(self) -> dict[str, float]:
    """Compute metrics from accumulated state.

    Returns:
        Dictionary mapping metric names to values.

    """
    raise NotImplementedError()

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

MultilabelAccuracy

MultilabelAccuracy(mode='subset', average='micro')

Bases: MultilabelClassificationMetric[LabelT], Generic[LabelT]

Multilabel classification accuracy metric.

Supports two modes:
  • "subset": Subset accuracy (exact match); the default, matching sklearn's accuracy_score
  • "samples": Sample-wise accuracy (label-wise, equivalent to 1 - hamming_loss)

For "samples" mode with micro averaging, computes accuracy across all label predictions. For "samples" mode with macro averaging, computes per-label accuracy then averages.

PARAMETER DESCRIPTION
mode

Accuracy calculation mode:
  • "subset": Counts exact matches (predicted set == target set)
  • "samples": Counts label-wise matches across all labels

TYPE: Literal['subset', 'samples'] DEFAULT: 'subset'

average

Averaging strategy (only applies to "samples" mode):
  • "micro": Overall accuracy across all samples
  • "macro": Average of per-label accuracies

TYPE: Literal['micro', 'macro'] DEFAULT: 'micro'

Examples:

>>> # Subset accuracy (exact match)
>>> metric = MultilabelAccuracy(mode="subset")
>>> inputs = ClassificationInput(
...     predictions=[[0, 1], [1, 2]],
...     targets=[[0, 1], [1, 2]]
... )
>>> metric.update(inputs)
>>> metric.compute()  # {"accuracy": 1.0} - both instances match exactly
>>>
>>> # Sample-wise accuracy
>>> metric = MultilabelAccuracy(mode="samples", average="micro")
>>> metric.compute()  # Counts all label predictions
Source code in src/formed/integrations/ml/metrics.py
def __init__(
    self,
    mode: Literal["subset", "samples"] = "subset",
    average: Literal["micro", "macro"] = "micro",
) -> None:
    self._mode = mode
    self._average = average

    if mode == "subset":
        # For subset accuracy, track exact matches
        self._exact_matches = 0
        self._total_instances = 0
    else:
        # For samples mode, track label-wise statistics
        self._correct: dict[LabelT, int] = defaultdict(int)
        self._total: dict[LabelT, int] = defaultdict(int)

Input class-attribute instance-attribute

reset

reset()
Source code in src/formed/integrations/ml/metrics.py
def reset(self) -> None:
    if self._mode == "subset":
        self._exact_matches = 0
        self._total_instances = 0
    else:
        self._correct = defaultdict(int)
        self._total = defaultdict(int)

update

update(inputs)
Source code in src/formed/integrations/ml/metrics.py
def update(self, inputs: ClassificationInput[Sequence[LabelT]]) -> None:
    predictions = inputs.predictions
    targets = inputs.targets
    assert len(predictions) == len(targets), "Predictions and targets must have the same length"

    if self._mode == "subset":
        # Subset accuracy: exact match of label sets
        for pred_labels, target_labels in zip(predictions, targets):
            if set(pred_labels) == set(target_labels):
                self._exact_matches += 1
            self._total_instances += 1
    else:
        # Sample-wise accuracy: count label matches
        # Get all labels that appear in any prediction or target
        all_labels: set[LabelT] = set()
        for pred_labels, target_labels in zip(predictions, targets):
            all_labels.update(pred_labels)
            all_labels.update(target_labels)

        # For each instance, check each label
        for pred_labels, target_labels in zip(predictions, targets):
            pred_set = set(pred_labels)
            target_set = set(target_labels)

            for label in all_labels:
                # Correct if presence matches: both have it or both don't
                if (label in pred_set) == (label in target_set):
                    self._correct[label] += 1
                self._total[label] += 1

compute

compute()
Source code in src/formed/integrations/ml/metrics.py
def compute(self) -> dict[str, float]:
    if self._mode == "subset":
        # Subset accuracy: fraction of exact matches
        accuracy = self._exact_matches / self._total_instances if self._total_instances > 0 else 0.0
        return {"accuracy": accuracy}
    else:
        # Sample-wise accuracy with micro or macro averaging
        if self._average == "micro":
            total_correct = sum(self._correct.values())
            total_count = sum(self._total.values())
            accuracy = total_correct / total_count if total_count > 0 else 0.0
            return {"accuracy": accuracy}
        elif self._average == "macro":
            accuracies = []
            for label in self._total.keys():
                correct = self._correct[label]
                total = self._total[label]
                accuracies.append(correct / total if total > 0 else 0.0)
            macro_accuracy = sum(accuracies) / len(accuracies) if accuracies else 0.0
            return {"accuracy": macro_accuracy}
        else:
            raise ValueError(f"Unknown average type: {self._average}")
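The difference between the two modes shows up as soon as a prediction is a strict subset of its target. A minimal standalone sketch (illustrative names, not the formed API) of subset accuracy versus micro-averaged sample-wise accuracy:

```python
def subset_accuracy(preds, targets):
    # Exact-match accuracy: the predicted label set must equal the target set.
    matches = sum(set(p) == set(t) for p, t in zip(preds, targets))
    return matches / len(targets)

def samples_accuracy_micro(preds, targets):
    # Label-wise accuracy over the union of observed labels (1 - Hamming loss).
    labels = set().union(*map(set, preds), *map(set, targets))
    correct = total = 0
    for p, t in zip(preds, targets):
        p_set, t_set = set(p), set(t)
        for label in labels:
            # A label is "correct" when its presence matches in both sets.
            correct += (label in p_set) == (label in t_set)
            total += 1
    return correct / total

preds, targets = [[0, 1], [1]], [[0, 1], [1, 2]]
print(subset_accuracy(preds, targets))          # 0.5 (second instance misses label 2)
print(samples_accuracy_micro(preds, targets))   # 5/6 (one wrong label decision out of 6)
```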

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

MultilabelFBeta

MultilabelFBeta(beta=1.0, average='micro')

Bases: MultilabelClassificationMetric[LabelT], Generic[LabelT]

Multilabel F-beta score with precision and recall.

Computes label-wise TP, FP, FN across all labels that appear in predictions or targets, then calculates F-beta, precision, and recall with micro or macro averaging.

Note: This metric computes sample-wise (label-wise) statistics. For multilabel data, each label is treated independently and statistics are aggregated across all labels.

PARAMETER DESCRIPTION
beta

Weight of recall relative to precision (default: 1.0 for F1).

TYPE: float DEFAULT: 1.0

average

Averaging strategy:
  • "micro": Compute globally across all labels
  • "macro": Compute per-label then average

TYPE: Literal['micro', 'macro'] DEFAULT: 'micro'

RETURNS DESCRIPTION

Dictionary with "fbeta", "precision", and "recall" metrics.

Examples:

>>> metric = MultilabelFBeta(beta=1.0, average="micro")
>>> inputs = ClassificationInput(
...     predictions=[[0, 1], [1], [0, 2]],
...     targets=[[0, 1], [1, 2], [0]]
... )
>>> metric.update(inputs)
>>> metric.compute()
>>> # {"fbeta": ..., "precision": ..., "recall": ...}
Source code in src/formed/integrations/ml/metrics.py
def __init__(self, beta: float = 1.0, average: Literal["micro", "macro"] = "micro") -> None:
    self._beta = beta
    self._average = average
    self._true_positive: dict[LabelT, int] = defaultdict(int)
    self._false_positive: dict[LabelT, int] = defaultdict(int)
    self._false_negative: dict[LabelT, int] = defaultdict(int)
    self._true_negative: dict[LabelT, int] = defaultdict(int)

Input class-attribute instance-attribute

reset

reset()
Source code in src/formed/integrations/ml/metrics.py
def reset(self) -> None:
    self._true_positive = defaultdict(int)
    self._false_positive = defaultdict(int)
    self._false_negative = defaultdict(int)
    self._true_negative = defaultdict(int)

update

update(inputs)
Source code in src/formed/integrations/ml/metrics.py
def update(self, inputs: ClassificationInput[Sequence[LabelT]]) -> None:
    predictions = inputs.predictions
    targets = inputs.targets
    assert len(predictions) == len(targets), "Predictions and targets must have the same length"

    # Get all labels that appear in any prediction or target
    all_labels: set[LabelT] = set()
    for pred_labels, target_labels in zip(predictions, targets):
        all_labels.update(pred_labels)
        all_labels.update(target_labels)

    # For each instance, compute TP/FP/FN/TN for each label
    for pred_labels, target_labels in zip(predictions, targets):
        pred_set = set(pred_labels)
        target_set = set(target_labels)

        for label in all_labels:
            in_pred = label in pred_set
            in_target = label in target_set

            if in_pred and in_target:
                self._true_positive[label] += 1
            elif in_pred and not in_target:
                self._false_positive[label] += 1
            elif not in_pred and in_target:
                self._false_negative[label] += 1
            else:  # not in_pred and not in_target
                self._true_negative[label] += 1

compute

compute()
Source code in src/formed/integrations/ml/metrics.py
def compute(self) -> dict[str, float]:
    beta_sq = self._beta**2

    if self._average == "micro":
        total_true_positive = sum(self._true_positive.values())
        total_false_positive = sum(self._false_positive.values())
        total_false_negative = sum(self._false_negative.values())

        precision_denominator = total_true_positive + total_false_positive
        recall_denominator = total_true_positive + total_false_negative

        precision = total_true_positive / precision_denominator if precision_denominator > 0 else 0.0
        recall = total_true_positive / recall_denominator if recall_denominator > 0 else 0.0

        if precision + recall == 0:
            fbeta = 0.0
        else:
            fbeta = (1 + beta_sq) * (precision * recall) / (beta_sq * precision + recall)

        return {"fbeta": fbeta, "precision": precision, "recall": recall}
    elif self._average == "macro":
        fbetas = []
        precisions = []
        recalls = []

        # Get all labels that have any TP, FP, or FN
        all_labels = (
            set(self._true_positive.keys()).union(self._false_positive.keys()).union(self._false_negative.keys())
        )

        for label in all_labels:
            tp = self._true_positive[label]
            fp = self._false_positive[label]
            fn = self._false_negative[label]

            precision_denominator = tp + fp
            recall_denominator = tp + fn

            precision = tp / precision_denominator if precision_denominator > 0 else 0.0
            recall = tp / recall_denominator if recall_denominator > 0 else 0.0

            if precision + recall == 0:
                fbeta = 0.0
            else:
                fbeta = (1 + beta_sq) * (precision * recall) / (beta_sq * precision + recall)

            fbetas.append(fbeta)
            precisions.append(precision)
            recalls.append(recall)
        macro_fbeta = sum(fbetas) / len(fbetas) if fbetas else 0.0
        macro_precision = sum(precisions) / len(precisions) if precisions else 0.0
        macro_recall = sum(recalls) / len(recalls) if recalls else 0.0
        return {"fbeta": macro_fbeta, "precision": macro_precision, "recall": macro_recall}
    else:
        raise ValueError(f"Unknown average type: {self._average}")
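Working the docstring example through the micro branch above, each (instance, label) pair becomes one binary decision. A standalone sketch (illustrative, not the formed API):

```python
def multilabel_f1_micro(preds, targets):
    # Treat each (instance, label) pair as a binary decision over the observed label set.
    labels = set().union(*map(set, preds), *map(set, targets))
    tp = fp = fn = 0
    for p, t in zip(preds, targets):
        p_set, t_set = set(p), set(t)
        for label in labels:
            if label in p_set and label in t_set:
                tp += 1
            elif label in p_set:
                fp += 1
            elif label in t_set:
                fn += 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0

# TP=4, FP=1 (label 2 in instance 3), FN=1 (label 2 in instance 2):
# precision = recall = 4/5, so F1 = 0.8.
print(multilabel_f1_micro([[0, 1], [1], [0, 2]], [[0, 1], [1, 2], [0]]))
```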

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

RegressionInput dataclass

RegressionInput(predictions, targets)

predictions instance-attribute

predictions

targets instance-attribute

targets

RegressionMetric

Bases: BaseMetric[RegressionInput]

Input class-attribute instance-attribute

reset abstractmethod

reset()

Reset internal state for a new evaluation.

This should clear all accumulated statistics.

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def reset(self) -> None:
    """Reset internal state for a new evaluation.

    This should clear all accumulated statistics.

    """
    raise NotImplementedError()

update abstractmethod

update(inputs)

Update internal state with a batch of predictions.

PARAMETER DESCRIPTION
inputs

Batch of predictions and targets.

TYPE: _T

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def update(self, inputs: _T) -> None:
    """Update internal state with a batch of predictions.

    Args:
        inputs: Batch of predictions and targets.

    """
    raise NotImplementedError()

compute abstractmethod

compute()

Compute metrics from accumulated state.

RETURNS DESCRIPTION
dict[str, float]

Dictionary mapping metric names to values.

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def compute(self) -> dict[str, float]:
    """Compute metrics from accumulated state.

    Returns:
        Dictionary mapping metric names to values.

    """
    raise NotImplementedError()

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

MeanSquaredError

MeanSquaredError()

Bases: RegressionMetric

Source code in src/formed/integrations/ml/metrics.py
def __init__(self) -> None:
    self._squared_error = 0.0
    self._count = 0

Input class-attribute instance-attribute

reset

reset()
Source code in src/formed/integrations/ml/metrics.py
def reset(self) -> None:
    self._squared_error = 0.0
    self._count = 0

update

update(inputs)
Source code in src/formed/integrations/ml/metrics.py
def update(self, inputs: RegressionInput) -> None:
    predictions = inputs.predictions
    targets = inputs.targets
    assert len(predictions) == len(targets), "Predictions and targets must have the same length"

    for pred, target in zip(predictions, targets):
        self._squared_error += (pred - target) ** 2
        self._count += 1

compute

compute()
Source code in src/formed/integrations/ml/metrics.py
def compute(self) -> dict[str, float]:
    mse = self._squared_error / self._count if self._count > 0 else 0.0
    return {"mean_squared_error": mse}

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

MeanAbsoluteError

MeanAbsoluteError()

Bases: RegressionMetric

Source code in src/formed/integrations/ml/metrics.py
def __init__(self) -> None:
    self._absolute_error = 0.0
    self._count = 0

Input class-attribute instance-attribute

reset

reset()
Source code in src/formed/integrations/ml/metrics.py
def reset(self) -> None:
    self._absolute_error = 0.0
    self._count = 0

update

update(inputs)
Source code in src/formed/integrations/ml/metrics.py
def update(self, inputs: RegressionInput) -> None:
    predictions = inputs.predictions
    targets = inputs.targets
    assert len(predictions) == len(targets), "Predictions and targets must have the same length"

    for pred, target in zip(predictions, targets):
        self._absolute_error += abs(pred - target)
        self._count += 1

compute

compute()
Source code in src/formed/integrations/ml/metrics.py
def compute(self) -> dict[str, float]:
    mae = self._absolute_error / self._count if self._count > 0 else 0.0
    return {"mean_absolute_error": mae}
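Both regression metrics above reduce to a running error sum divided by the sample count. A one-pass standalone sketch (illustrative inputs, not the formed API):

```python
def mse_mae(preds, targets):
    # Accumulate squared and absolute errors in one pass, as the metrics above do.
    squared = sum((p - t) ** 2 for p, t in zip(preds, targets))
    absolute = sum(abs(p - t) for p, t in zip(preds, targets))
    n = len(targets)
    return {"mean_squared_error": squared / n, "mean_absolute_error": absolute / n}

# Errors are -0.5, 0.5, 0.0: MSE = 0.5/3, MAE = 1.0/3.
print(mse_mae([2.5, 0.0, 2.0], [3.0, -0.5, 2.0]))
```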

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

RankingInput dataclass

RankingInput(predictions, targets)

Bases: Generic[LabelT]

predictions instance-attribute

predictions

targets instance-attribute

targets

RankingMetric

Bases: BaseMetric[RankingInput[LabelT]], Generic[LabelT]

Input class-attribute instance-attribute

Input = RankingInput

reset abstractmethod

reset()

Reset internal state for a new evaluation.

This should clear all accumulated statistics.

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def reset(self) -> None:
    """Reset internal state for a new evaluation.

    This should clear all accumulated statistics.

    """
    raise NotImplementedError()

update abstractmethod

update(inputs)

Update internal state with a batch of predictions.

PARAMETER DESCRIPTION
inputs

Batch of predictions and targets.

TYPE: _T

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def update(self, inputs: _T) -> None:
    """Update internal state with a batch of predictions.

    Args:
        inputs: Batch of predictions and targets.

    """
    raise NotImplementedError()

compute abstractmethod

compute()

Compute metrics from accumulated state.

RETURNS DESCRIPTION
dict[str, float]

Dictionary mapping metric names to values.

Source code in src/formed/integrations/ml/metrics.py
@abc.abstractmethod
def compute(self) -> dict[str, float]:
    """Compute metrics from accumulated state.

    Returns:
        Dictionary mapping metric names to values.

    """
    raise NotImplementedError()

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

MeanAveragePrecision

MeanAveragePrecision()

Bases: RankingMetric[LabelT], Generic[LabelT]

Source code in src/formed/integrations/ml/metrics.py
def __init__(self) -> None:
    self._average_precisions: list[float] = []

Input class-attribute instance-attribute

Input = RankingInput

reset

reset()
Source code in src/formed/integrations/ml/metrics.py
def reset(self) -> None:
    self._average_precisions = []

update

update(inputs)
Source code in src/formed/integrations/ml/metrics.py
def update(self, inputs: RankingInput[LabelT]) -> None:
    predictions = inputs.predictions
    targets = inputs.targets
    assert len(predictions) == len(targets), "Predictions and targets must have the same length"

    for pred_scores, target_labels in zip(predictions, targets):
        sorted_labels = sorted(pred_scores.keys(), key=lambda x: pred_scores[x], reverse=True)
        relevant_set = set(target_labels)

        num_relevant = 0
        precision_sum = 0.0

        for rank, label in enumerate(sorted_labels, start=1):
            if label in relevant_set:
                num_relevant += 1
                precision_sum += num_relevant / rank

        average_precision = precision_sum / len(relevant_set) if relevant_set else 0.0
        self._average_precisions.append(average_precision)

compute

compute()
Source code in src/formed/integrations/ml/metrics.py
def compute(self) -> dict[str, float]:
    mean_ap = sum(self._average_precisions) / len(self._average_precisions) if self._average_precisions else 0.0
    return {"mean_average_precision": mean_ap}
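Per query, the average precision computed in `update` above averages the precision at each rank where a relevant item appears. A standalone sketch for a single query (illustrative labels and scores, not the formed API):

```python
def average_precision(scores, relevant):
    # Rank labels by predicted score, then average precision at each relevant hit.
    ranked = sorted(scores, key=scores.get, reverse=True)
    hits, precision_sum = 0, 0.0
    for rank, label in enumerate(ranked, start=1):
        if label in relevant:
            hits += 1
            precision_sum += hits / rank  # precision@rank at this hit
    return precision_sum / len(relevant) if relevant else 0.0

# Relevant items land at ranks 1 and 3: AP = (1/1 + 2/3) / 2 = 5/6.
print(average_precision({"a": 0.9, "b": 0.8, "c": 0.1}, {"a", "c"}))
```

MeanAveragePrecision then averages this quantity over all queries seen in `update`.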

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

NDCG

NDCG(k=10)

Bases: RankingMetric[LabelT], Generic[LabelT]

Source code in src/formed/integrations/ml/metrics.py
def __init__(self, k: int = 10) -> None:
    self._k = k
    self._ndcgs: list[float] = []

Input class-attribute instance-attribute

Input = RankingInput

reset

reset()
Source code in src/formed/integrations/ml/metrics.py
def reset(self) -> None:
    self._ndcgs = []

update

update(inputs)
Source code in src/formed/integrations/ml/metrics.py
def update(self, inputs: RankingInput[LabelT]) -> None:
    predictions = inputs.predictions
    targets = inputs.targets
    assert len(predictions) == len(targets), "Predictions and targets must have the same length"

    for pred_scores, target_labels in zip(predictions, targets):
        sorted_labels = sorted(pred_scores.keys(), key=lambda x: pred_scores[x], reverse=True)
        relevant_set = set(target_labels)

        dcg = 0.0
        for rank, label in enumerate(sorted_labels[: self._k], start=1):
            if label in relevant_set:
                dcg += 1 / math.log2(rank + 1)

        ideal_dcg = 0.0
        for rank in range(1, min(len(relevant_set), self._k) + 1):
            ideal_dcg += 1 / math.log2(rank + 1)

        ndcg = dcg / ideal_dcg if ideal_dcg > 0 else 0.0
        self._ndcgs.append(ndcg)

compute

compute()
Source code in src/formed/integrations/ml/metrics.py
def compute(self) -> dict[str, float]:
    mean_ndcg = sum(self._ndcgs) / len(self._ndcgs) if self._ndcgs else 0.0
    return {"ndcg": mean_ndcg}
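The `update` method above uses binary relevance: each relevant item in the top k contributes gain 1/log2(rank + 1), normalized by the ideal ordering. A single-query standalone sketch (illustrative inputs, not the formed API):

```python
import math

def ndcg_binary(scores, relevant, k=10):
    # Binary-relevance NDCG@k: gain 1/log2(rank + 1) for each relevant item in the top k.
    ranked = sorted(scores, key=scores.get, reverse=True)
    dcg = sum(1 / math.log2(rank + 1)
              for rank, label in enumerate(ranked[:k], start=1)
              if label in relevant)
    # Ideal DCG: all relevant items packed into the top ranks.
    ideal = sum(1 / math.log2(rank + 1)
                for rank in range(1, min(len(relevant), k) + 1))
    return dcg / ideal if ideal > 0 else 0.0

# Relevant items at ranks 1 and 3: DCG = 1 + 1/log2(4) = 1.5,
# ideal DCG = 1 + 1/log2(3), so NDCG is slightly below 1.
print(ndcg_binary({"a": 0.9, "b": 0.8, "c": 0.7}, {"a", "c"}, k=3))
```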

clone

clone()

Create a deep copy of this metric.

RETURNS DESCRIPTION
Self

A new instance of the metric with the same internal state.

Source code in src/formed/integrations/ml/metrics.py
def clone(self) -> Self:
    """Create a deep copy of this metric.
    Returns:
        A new instance of the metric with the same internal state.
    """
    return copy.deepcopy(self)

formed.integrations.ml.transforms.base

Base classes and utilities for data transformations.

This module provides the core infrastructure for building type-safe, composable data transformations in machine learning pipelines. It supports both single-instance and batched processing with strong type guarantees.

Key Components
  • BaseTransform: Abstract base class for all transformations
  • DataModule: Composable data transformation container
  • Extra: Descriptor for optional fields (e.g., labels in test data)
  • Param: Descriptor for non-transformed parameters
  • register_dataclass: Function to register dataclasses with JAX pytree
Design Patterns
  • Descriptor protocol for field access control
  • Generic type parameters for type safety
  • Mode-based behavior (AsInstance, AsBatch, AsConverter)
  • Automatic JAX pytree registration for compatibility with jax.jit/jax.vmap

Examples:

>>> from formed.integrations.ml import DataModule, TensorTransform, LabelIndexer, Extra
>>>
>>> class MyDataModule(DataModule[DataModuleModeT, dict, ...]):
...     features: TensorTransform
...     label: Extra[LabelIndexer] = Extra.default()
>>>
>>> dm = MyDataModule(features=TensorTransform(), label=LabelIndexer())
>>> with dm.train():
...     instance = dm.instance({"features": [1.0, 2.0], "label": "positive"})
>>> batch = dm.batch([instance1, instance2, instance3])
Note

If JAX is installed, all DataModule instances are automatically registered as JAX pytrees for compatibility with JAX transformations.

logger module-attribute

logger = getLogger(__name__)

Extra

Extra(*args, **kwargs)

Bases: Generic[_BaseTransformT_co]

Descriptor marker for optional transformation fields in DataModule.

Extra fields are optional and can be None, which is useful for fields that may not be present in all data (e.g., labels in test/inference data). When accessed, Extra fields return the transformed value in instance/batch mode, or the transform itself in converter mode.

CLASS TYPE PARAMETER DESCRIPTION
_BaseTransformT_co

The transform type (covariant).

Examples:

>>> class MyDataModule(DataModule[...]):
...     text: Tokenizer
...     label: Extra[LabelIndexer] = Extra.default()  # Optional field
>>>
>>> # Training mode with labels
>>> train_dm = MyDataModule(text=Tokenizer(), label=LabelIndexer())
>>> with train_dm.train():
...     instance = train_dm.instance({"text": "hello", "label": "positive"})
>>> print(instance.label)  # Returns the transformed label index
>>>
>>> # Inference mode without labels
>>> test_dm = MyDataModule(text=Tokenizer(), label=None)
>>> test_instance = test_dm.instance({"text": "hello"})
>>> print(test_instance.label)  # Returns None
Note

Extra is a marker class and cannot be instantiated directly. Use Extra.default() to provide a default value.

Source code in src/formed/integrations/ml/transforms/base.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    raise TypeError("Extra is a marker class and cannot be instantiated directly")

default classmethod

default(default=None)

Create a default Extra field with an optional default transform.

PARAMETER DESCRIPTION
default

Optional default transform to use if not specified.

TYPE: _BaseTransformT | None DEFAULT: None

RETURNS DESCRIPTION
Extra[_BaseTransformT]

An Extra field with the specified default.

Source code in src/formed/integrations/ml/transforms/base.py
@classmethod
def default(
    cls: type["Extra[_BaseTransformT]"],
    default: _BaseTransformT | None = None,
) -> "Extra[_BaseTransformT]":
    """Create a default Extra field with an optional default transform.

    Args:
        default: Optional default transform to use if not specified.

    Returns:
        An Extra field with the specified default.

    """
    return cast(Extra[_BaseTransformT], default)

default_factory classmethod

default_factory(factory)

Create a factory for an Extra field with an optional default transform.

PARAMETER DESCRIPTION
factory

A callable that returns the default transform.

TYPE: Callable[[], _BaseTransformT | None]

RETURNS DESCRIPTION
Callable[[], Extra[_BaseTransformT]]

A factory callable for creating Extra fields.

Source code in src/formed/integrations/ml/transforms/base.py
@classmethod
def default_factory(
    cls: type["Extra[_BaseTransformT]"],
    factory: Callable[[], _BaseTransformT | None],
) -> Callable[[], "Extra[_BaseTransformT]"]:
    """Create a factory for an Extra field with an optional default transform.

    Args:
        factory: A callable that returns the default transform.

    Returns:
        A factory callable for creating Extra fields.

    """
    return cast(Callable[[], Extra[_BaseTransformT]], factory)
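Both `default()` and `default_factory()` rely on `typing.cast`: at runtime the marker type never exists as a value; it exists only to guide static checkers. A minimal sketch of this pattern, using a hypothetical `Marker` class that is not part of the library:

```python
import dataclasses
from typing import Generic, Optional, TypeVar, cast

T = TypeVar("T")

class Marker(Generic[T]):
    """Type-only wrapper: no instances are ever created at runtime."""

    @classmethod
    def default(cls, default: Optional[T] = None) -> "Marker[T]":
        # The runtime value is just `default`; cast() satisfies the checker.
        return cast("Marker[T]", default)

@dataclasses.dataclass
class Config:
    label: Marker[str] = Marker.default()

# At runtime the field holds the plain value (here: None).
```

This is why accessing an `Extra` field returns the transform (or `None`) directly rather than a wrapper object.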

Param

Param()

Bases: Generic[_T_co]

Descriptor marker for non-transformed parameter fields in DataModule.

Param fields represent parameters that pass through unchanged during instance/batch conversion. They are not transformed but remain accessible in all modes. This is useful for hyperparameters, configuration values, or other metadata that should be available but not processed.

CLASS TYPE PARAMETER DESCRIPTION
_T_co

The parameter type (covariant).

Examples:

>>> class MyDataModule(DataModule[...]):
...     text: Tokenizer
...     max_length: Param[int] = Param.default(128)
...     temperature: Param[float] = Param.default(1.0)
>>>
>>> dm = MyDataModule(text=Tokenizer(), max_length=256, temperature=0.8)
>>> instance = dm.instance({"text": "hello"})
>>> print(instance.max_length)  # Returns 256 (unchanged)
>>> batch = dm.batch([instance1, instance2])
>>> print(batch.max_length)  # Still returns 256
Note

Param is a marker class and cannot be instantiated directly. Use Param.default() or Param.default_factory() to provide defaults.

Source code in src/formed/integrations/ml/transforms/base.py
def __init__(self) -> None:
    raise TypeError("Param is a marker class and cannot be instantiated directly")

default classmethod

default(default)

Create a Param field with a default value.

PARAMETER DESCRIPTION
default

The default value for this parameter.

TYPE: _T

RETURNS DESCRIPTION
Param[_T]

A Param field with the specified default.

Source code in src/formed/integrations/ml/transforms/base.py
@classmethod
def default(cls: type["Param[_T]"], default: _T) -> "Param[_T]":
    """Create a Param field with a default value.

    Args:
        default: The default value for this parameter.

    Returns:
        A Param field with the specified default.

    """
    return cast(Param[_T], default)

cast classmethod

cast(value)

Wrap a value as a Param field.

PARAMETER DESCRIPTION
value

The value to wrap as a Param.

TYPE: _T

RETURNS DESCRIPTION
Param[_T]

A Param field wrapping the given value.

Source code in src/formed/integrations/ml/transforms/base.py
@classmethod
def cast(cls: type["Param[_T]"], value: _T) -> "Param[_T]":
    """Wrap a value as a Param field.

    Args:
        value: The value to wrap as a Param.
    Returns:
        A Param field wrapping the given value.

    """
    return cast(Param[_T], value)

default_factory classmethod

default_factory(factory)

Create a Param field with a default factory function.

PARAMETER DESCRIPTION
factory

A callable that returns the default value.

TYPE: Callable[[], _T]

RETURNS DESCRIPTION
Callable[[], Param[_T]]

A factory callable for creating Param fields.

Source code in src/formed/integrations/ml/transforms/base.py
@classmethod
def default_factory(
    cls: type["Param[_T]"],
    factory: Callable[[], _T],
) -> Callable[[], "Param[_T]"]:
    """Create a Param field with a default factory function.

    Args:
        factory: A callable that returns the default value.

    Returns:
        A factory callable for creating Param fields.

    """
    return cast(Callable[[], Param[_T]], factory)

BaseTransformMeta

Bases: ABCMeta
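The metaclass has no documented members here, but the BaseTransform notes below state that subclasses are automatically converted to dataclasses via the metaclass. A hedged sketch of how such a metaclass could work, illustrative only and not the actual implementation:

```python
import abc
import dataclasses

class AutoDataclassMeta(abc.ABCMeta):
    """Metaclass that applies @dataclasses.dataclass to every new class."""

    def __new__(mcls, name, bases, namespace, **kwargs):
        cls = super().__new__(mcls, name, bases, namespace, **kwargs)
        # Decorating here means subclasses never need @dataclass themselves.
        return dataclasses.dataclass(cls)

class Transform(metaclass=AutoDataclassMeta):
    pass

class Scale(Transform):
    factor: float = 1.0  # picked up as a dataclass field automatically
```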

BaseTransform

Bases: Registrable, Generic[_S, _T, InstanceT_co, BatchT_co], ABC

Abstract base class for data transformations.

BaseTransform provides a two-stage transformation pipeline:
  1. Instance transformation: Convert raw data to per-instance representation
  2. Batch transformation: Collate multiple instances into batched tensors

The class uses descriptors for flexible field access and supports training/inference modes for stateful transformations (e.g., vocabulary building).

CLASS TYPE PARAMETER DESCRIPTION
_S

Source data type before accessor is applied

_T

Target data type after accessor is applied

InstanceT_co

Instance representation type (covariant)

BatchT_co

Batch representation type (covariant)

ATTRIBUTE DESCRIPTION
accessor

Optional accessor to extract the relevant field from input data. Can be a string (attribute/key name) or a callable.

TYPE: str | Callable[[_S], _T] | None

Class Attributes

  • is_static: If True, indicates the batched value is static for JAX.
  • process_parent: If True, the accessor receives the entire parent object.

Abstract Methods

  • instance: Transform a single data point to its instance representation.
  • batch: Collate a sequence of instances into a batched representation.

Examples:

>>> class LowercaseTransform(BaseTransform[dict, str, str, list[str]]):
...     def instance(self, text: str) -> str:
...         return text.lower()
...
...     def batch(self, instances: Sequence[str]) -> list[str]:
...         return list(instances)
>>>
>>> transform = LowercaseTransform(accessor="text")
>>> instance = transform({"text": "HELLO"})  # Returns "hello"
>>> batch = transform.batch(["hello", "world"])  # Returns ["hello", "world"]
Note
  • Subclasses are automatically converted to dataclasses via metaclass.
  • Use the train() context manager for stateful transformations.
  • Supports saving/loading with cloudpickle for persistence.

accessor class-attribute instance-attribute

accessor = None

instance abstractmethod

instance(obj)

Transform a single data point to its instance representation.

PARAMETER DESCRIPTION
obj

The input data after accessor extraction.

TYPE: _T

RETURNS DESCRIPTION
InstanceT_co

The transformed instance representation.

Note

This method is called for each individual data point.

Source code in src/formed/integrations/ml/transforms/base.py
@abc.abstractmethod
def instance(self, obj: _T, /) -> InstanceT_co:
    """Transform a single data point to its instance representation.

    Args:
        obj: The input data after accessor extraction.

    Returns:
        The transformed instance representation.

    Note:
        This method is called for each individual data point.

    """
    raise NotImplementedError("Subclasses must implement this method")

batch abstractmethod

batch(batch)

Collate multiple instances into a batched representation.

PARAMETER DESCRIPTION
batch

A sequence of instance representations from instance().

TYPE: Sequence[InstanceT_co]

RETURNS DESCRIPTION
BatchT_co

The batched representation, typically as tensors or arrays.

Note

This method should handle padding, stacking, or other batching logic.

Source code in src/formed/integrations/ml/transforms/base.py
@abc.abstractmethod
def batch(self, batch: Sequence[InstanceT_co], /) -> BatchT_co:
    """Collate multiple instances into a batched representation.

    Args:
        batch: A sequence of instance representations from `instance()`.

    Returns:
        The batched representation, typically as tensors or arrays.

    Note:
        This method should handle padding, stacking, or other batching logic.

    """
    raise NotImplementedError("Subclasses must implement this method")

train

train()

Context manager to enable training mode for stateful transformations.

In training mode, transforms can build state (e.g., vocabularies, statistics) from the training data. Hooks _on_start_training() and _on_end_training() are called at the beginning and end of the training context.

YIELDS DESCRIPTION
None

None

Examples:

>>> indexer = TokenSequenceIndexer()
>>> with indexer.train():
...     # Build vocabulary from training data
...     tokens1 = indexer.instance(["hello", "world"])
...     tokens2 = indexer.instance(["hello", "there"])
>>> # Vocabulary is now frozen, use for inference
>>> test_tokens = indexer.instance(["hello", "unknown"])
Note

Training mode is reentrant but nested calls won't trigger hooks again.

Source code in src/formed/integrations/ml/transforms/base.py
@contextmanager
def train(self) -> Iterator[None]:
    """Context manager to enable training mode for stateful transformations.

    In training mode, transforms can build state (e.g., vocabularies, statistics)
    from the training data. Hooks `_on_start_training()` and `_on_end_training()`
    are called at the beginning and end of the training context.

    Yields:
        None

    Examples:
        >>> indexer = TokenSequenceIndexer()
        >>> with indexer.train():
        ...     # Build vocabulary from training data
        ...     tokens1 = indexer.instance(["hello", "world"])
        ...     tokens2 = indexer.instance(["hello", "there"])
        >>> # Vocabulary is now frozen, use for inference
        >>> test_tokens = indexer.instance(["hello", "unknown"])

    Note:
        Training mode is reentrant but nested calls won't trigger hooks again.

    """
    original = self._training
    self._training = True
    try:
        if not original:
            self._on_start_training()
        yield
        if not original:
            self._on_end_training()
    finally:
        self._training = original

save

save(directory)

Save the transform to a directory using cloudpickle.

PARAMETER DESCRIPTION
directory

Directory path to save the transform.

TYPE: str | PathLike

Note

The transform is saved as 'transform.pkl' in the specified directory. cloudpickle is used to handle complex objects like lambdas and closures.

Source code in src/formed/integrations/ml/transforms/base.py
def save(self, directory: str | PathLike) -> None:
    """Save the transform to a directory using cloudpickle.

    Args:
        directory: Directory path to save the transform.

    Note:
        The transform is saved as 'transform.pkl' in the specified directory.
        cloudpickle is used to handle complex objects like lambdas and closures.

    """
    filepath = Path(directory) / "transform.pkl"
    with filepath.open("wb") as f:
        cloudpickle.dump(self, f)

load classmethod

load(directory)

Load a transform from a directory.

PARAMETER DESCRIPTION
directory

Directory path containing the saved transform.

TYPE: str | PathLike

RETURNS DESCRIPTION
Self

The loaded transform instance.

RAISES DESCRIPTION
TypeError

If the loaded object is not an instance of this class.

Note

Expects a 'transform.pkl' file in the specified directory.

Source code in src/formed/integrations/ml/transforms/base.py
@classmethod
def load(cls, directory: str | PathLike) -> Self:
    """Load a transform from a directory.

    Args:
        directory: Directory path containing the saved transform.

    Returns:
        The loaded transform instance.

    Raises:
        TypeError: If the loaded object is not an instance of this class.

    Note:
        Expects a 'transform.pkl' file in the specified directory.

    """
    filepath = Path(directory) / "transform.pkl"
    with filepath.open("rb") as f:
        obj = cloudpickle.load(f)
    if not isinstance(obj, cls):
        raise TypeError(f"Loaded object is not an instance of {cls.__name__}")
    return obj
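The save/load round-trip can be exercised with the stdlib pickle module as well (the library itself uses cloudpickle, which additionally handles lambdas and closures). A minimal sketch of the same pattern, with hypothetical helper names:

```python
import pickle
import tempfile
from pathlib import Path

def save_transform(obj, directory):
    """Persist an object as 'transform.pkl' inside `directory`."""
    with (Path(directory) / "transform.pkl").open("wb") as f:
        pickle.dump(obj, f)

def load_transform(directory, expected_type):
    """Load 'transform.pkl' and verify the restored object's type."""
    with (Path(directory) / "transform.pkl").open("rb") as f:
        obj = pickle.load(f)
    if not isinstance(obj, expected_type):
        raise TypeError(f"Loaded object is not an instance of {expected_type.__name__}")
    return obj

with tempfile.TemporaryDirectory() as tmp:
    save_transform({"vocab": {"hello": 0}}, tmp)
    restored = load_transform(tmp, dict)
```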

DataModule

Bases: BaseTransform[_T, _T, _InstanceT, _BatchT], Generic[_DataModuleModeT_co, _T, _InstanceT, _BatchT]

Composable container for multiple data transformations with mode-based behavior.

DataModule orchestrates multiple BaseTransform fields and switches between three modes:
  • AsConverter: Configuration mode, holds transform logic
  • AsInstance: Single data point after per-instance transformation
  • AsBatch: Multiple instances collated into batched tensors

This enables a single class definition to represent raw data, transformed instances, and batched tensors with full type safety.

CLASS TYPE PARAMETER DESCRIPTION
_DataModuleModeT_co

Current mode (AsConverter, AsInstance, or AsBatch)

_T

Input data type

_InstanceT

Instance mode type (self when mode=AsInstance)

_BatchT

Batch mode type (self when mode=AsBatch)

Field Types
  • Regular fields: BaseTransform subclasses that transform data
  • Extra fields: Optional transforms (e.g., labels for test data)
  • Param fields: Non-transformed parameters that pass through unchanged

Examples:

>>> @dataclasses.dataclass
... class TextExample:
...     text: str
...     label: Optional[str] = None
>>>
>>> class TextDataModule(DataModule[DataModuleModeT, TextExample, ...]):
...     text: Tokenizer
...     label: Extra[LabelIndexer] = Extra.default()
>>>
>>> # Create converter (configuration)
>>> dm = TextDataModule(
...     text=Tokenizer(surfaces=TokenSequenceIndexer()),
...     label=LabelIndexer()
... )
>>>
>>> # Training: build vocabularies
>>> with dm.train():
...     train_instances = [
...         dm.instance(example)
...         for example in train_data
...     ]
>>>
>>> # Create batches
>>> batch = dm.batch(train_instances[:32])
>>> print(batch.text.surfaces.ids.shape)  # (32, max_length)
>>> print(batch.label.shape)  # (32,)
>>>
>>> # Inference without labels
>>> test_dm = TextDataModule(text=dm.text, label=None)
>>> test_instance = test_dm.instance(TextExample("test sentence"))
>>> print(test_instance.label)  # None
Note
  • Automatically registered as JAX pytree if JAX is available
  • Mode transitions are enforced by type system
  • Fields are descriptors with mode-dependent behavior

accessor class-attribute instance-attribute

accessor = None

train

train()

Context manager to enable training mode for all field transforms.

This propagates training mode to all BaseTransform fields, allowing them to build state (e.g., vocabularies) from training data.

YIELDS DESCRIPTION
None

None

Examples:

>>> dm = TextDataModule(text=Tokenizer(), label=LabelIndexer())
>>> with dm.train():
...     instances = [dm.instance(example) for example in train_data]
>>> # Vocabularies are now built and frozen
Note

Can only be called in AsConverter mode.

Source code in src/formed/integrations/ml/transforms/base.py
@contextmanager
def train(self) -> Iterator[None]:
    """Context manager to enable training mode for all field transforms.

    This propagates training mode to all BaseTransform fields, allowing them
    to build state (e.g., vocabularies) from training data.

    Yields:
        None

    Examples:
        >>> dm = TextDataModule(text=Tokenizer(), label=LabelIndexer())
        >>> with dm.train():
        ...     instances = [dm.instance(example) for example in train_data]
        >>> # Vocabularies are now built and frozen

    Note:
        Can only be called in AsConverter mode.

    """
    assert self.__mode__ in (None, DataModuleMode.AS_CONVERTER), (
        "DataModule must be in converter mode to enter training mode"
    )
    with ExitStack() as stack:
        for transform in self.__field_transforms__.values():
            stack.enter_context(transform.train())
        yield

instance

instance(obj)

Transform raw data into an instance representation.

Applies all field transforms to create a DataModule in AsInstance mode. Each transform field processes the corresponding data attribute/key.

PARAMETER DESCRIPTION
obj

The raw input data object.

TYPE: _T

RETURNS DESCRIPTION
_InstanceT

A DataModule in AsInstance mode with transformed fields.

Examples:

>>> dm = TextDataModule(text=Tokenizer(), label=LabelIndexer())
>>> instance = dm.instance({"text": "hello world", "label": "positive"})
>>> print(instance.text.surfaces)  # Tokenized text
>>> print(instance.label)  # Label index
Note
  • Can only be called in AsConverter mode
  • Returns a new DataModule with mode=AsInstance
  • Extra fields can be None if data is missing
Source code in src/formed/integrations/ml/transforms/base.py
def instance(self: "DataModule[AsConverter]", obj: _T, /) -> _InstanceT:
    """Transform raw data into an instance representation.

    Applies all field transforms to create a DataModule in AsInstance mode.
    Each transform field processes the corresponding data attribute/key.

    Args:
        obj: The raw input data object.

    Returns:
        A DataModule in AsInstance mode with transformed fields.

    Examples:
        >>> dm = TextDataModule(text=Tokenizer(), label=LabelIndexer())
        >>> instance = dm.instance({"text": "hello world", "label": "positive"})
        >>> print(instance.text.surfaces)  # Tokenized text
        >>> print(instance.label)  # Label index

    Note:
        - Can only be called in `AsConverter` mode
        - Returns a new `DataModule` with `mode=AsInstance`
        - Extra fields can be `None` if data is missing

    """
    assert self.__mode__ in (None, DataModuleMode.AS_CONVERTER), (
        "DataModule must be in converter mode to create an instance"
    )

    fields = {}
    for name, transform in self.__field_transforms__.items():
        fields[name] = transform(obj)
    for name, field in self.__class__.__get_param_fields__().items():
        if (
            name not in fields
            and field.default is not dataclasses.MISSING
            and field.default_factory is dataclasses.MISSING
        ):
            fields[name] = _UNAVAILABLE

    instance = cast(_InstanceT, dataclasses.replace(self, **fields))
    setattr(instance, "__mode__", DataModuleMode.AS_INSTANCE)

    return instance

batch

batch(instances)

Collate multiple instances into a batched representation.

Takes a sequence of raw data or instances and creates a DataModule in AsBatch mode. Each transform field's batch() method is called to collate the corresponding field values.

PARAMETER DESCRIPTION
instances

Sequence of raw data or DataModule instances.

TYPE: Sequence[_T | _InstanceT]

RETURNS DESCRIPTION
_BatchT

A DataModule in AsBatch mode with batched tensor fields.

Examples:

>>> dm = TextDataModule(text=Tokenizer(), label=LabelIndexer())
>>> instances = [dm.instance(ex) for ex in examples]
>>> batch = dm.batch(instances)
>>> print(batch.text.surfaces.ids.shape)  # (batch_size, seq_length)
>>> print(batch.label.shape)  # (batch_size,)
>>> print(len(batch))  # batch_size
Note
  • Can only be called in AsConverter mode
  • Automatically converts raw data to instances if needed
  • Returns a new DataModule with mode=AsBatch
  • Extra fields are None if all instances have None for that field
Source code in src/formed/integrations/ml/transforms/base.py
def batch(self: "DataModule[AsConverter]", instances: Sequence[_T | _InstanceT]) -> _BatchT:
    """Collate multiple instances into a batched representation.

    Takes a sequence of raw data or instances and creates a `DataModule` in
    `AsBatch` mode. Each transform field's `batch()` method is called to collate
    the corresponding field values.

    Args:
        instances: Sequence of raw data or `DataModule` instances.

    Returns:
        A `DataModule` in `AsBatch` mode with batched tensor fields.

    Examples:
        >>> dm = TextDataModule(text=Tokenizer(), label=LabelIndexer())
        >>> instances = [dm.instance(ex) for ex in examples]
        >>> batch = dm.batch(instances)
        >>> print(batch.text.surfaces.ids.shape)  # (batch_size, seq_length)
        >>> print(batch.label.shape)  # (batch_size,)
        >>> print(len(batch))  # batch_size

    Note:
        - Can only be called in `AsConverter` mode
        - Automatically converts raw data to instances if needed
        - Returns a new `DataModule` with `mode=AsBatch`
        - Extra fields are `None` if all instances have None for that field

    """
    assert self.__mode__ in (None, DataModuleMode.AS_CONVERTER), (
        "DataModule must be in converter mode to create a batch"
    )

    instances = [item if isinstance(item, DataModule) else self.instance(item) for item in instances]
    fields = {}
    for name, transform in self.__field_transforms__.items():
        can_be_optional = name in self.__class__.__get_extra_fields__()
        values = [getattr(instance, name) for instance in instances]
        if can_be_optional and all(value is None for value in values):
            fields[name] = None
        else:
            fields[name] = transform.batch(values)
    for name in self.__class__.__get_param_fields__().keys():
        if name not in fields:
            fields[name] = _UNAVAILABLE

    batch = cast(_BatchT, dataclasses.replace(self, **fields))
    setattr(batch, "__mode__", DataModuleMode.AS_BATCH)

    batch._batch_size = len(instances)
    return batch

save

save(directory)

Save the transform to a directory using cloudpickle.

PARAMETER DESCRIPTION
directory

Directory path to save the transform.

TYPE: str | PathLike

Note

The transform is saved as 'transform.pkl' in the specified directory. cloudpickle is used to handle complex objects like lambdas and closures.

Source code in src/formed/integrations/ml/transforms/base.py
def save(self, directory: str | PathLike) -> None:
    """Save the transform to a directory using cloudpickle.

    Args:
        directory: Directory path to save the transform.

    Note:
        The transform is saved as 'transform.pkl' in the specified directory.
        cloudpickle is used to handle complex objects like lambdas and closures.

    """
    filepath = Path(directory) / "transform.pkl"
    with filepath.open("wb") as f:
        cloudpickle.dump(self, f)

load classmethod

load(directory)

Load a transform from a directory.

PARAMETER DESCRIPTION
directory

Directory path containing the saved transform.

TYPE: str | PathLike

RETURNS DESCRIPTION
Self

The loaded transform instance.

RAISES DESCRIPTION
TypeError

If the loaded object is not an instance of this class.

Note

Expects a 'transform.pkl' file in the specified directory.

Source code in src/formed/integrations/ml/transforms/base.py
@classmethod
def load(cls, directory: str | PathLike) -> Self:
    """Load a transform from a directory.

    Args:
        directory: Directory path containing the saved transform.

    Returns:
        The loaded transform instance.

    Raises:
        TypeError: If the loaded object is not an instance of this class.

    Note:
        Expects a 'transform.pkl' file in the specified directory.

    """
    filepath = Path(directory) / "transform.pkl"
    with filepath.open("rb") as f:
        obj = cloudpickle.load(f)
    if not isinstance(obj, cls):
        raise TypeError(f"Loaded object is not an instance of {cls.__name__}")
    return obj

register_dataclass

register_dataclass(cls)

Register a dataclass with JAX pytree if JAX is available.

This function automatically registers dataclasses as JAX pytrees, enabling them to be used with JAX transformations like jax.jit, jax.vmap, and jax.grad. It distinguishes between data fields, metadata fields, and fields to drop based on field metadata and the Param/Extra field markers.

PARAMETER DESCRIPTION
cls

A dataclass type to register.

TYPE: _TypeT

RETURNS DESCRIPTION
_TypeT

The same class, now registered as a JAX pytree (if JAX is installed).

Note
  • If JAX is not installed, this function does nothing.
  • Fields marked with JAX_STATIC_FIELD metadata become meta_fields.
  • Fields with init=False and not marked as Param are dropped.
  • Registration is idempotent; registering twice has no effect.
  • If the class is a DataModule, recursively registers nested dataclasses.

Examples:

>>> @dataclasses.dataclass
... class MyData:
...     values: list[float]
...     metadata: str = dataclasses.field(metadata={JAX_STATIC_FIELD: True})
>>> register_dataclass(MyData)
>>> # Now MyData can be used with JAX transformations
Source code in src/formed/integrations/ml/transforms/base.py
def register_dataclass(cls: _TypeT) -> _TypeT:
    """Register a dataclass with JAX pytree if JAX is available.

    This function automatically registers dataclasses as JAX pytrees, enabling
    them to be used with JAX transformations like jax.jit, jax.vmap, and jax.grad.
    It distinguishes between data fields, metadata fields, and fields to drop based
    on field metadata and the Param/Extra field markers.

    Args:
        cls: A dataclass type to register.

    Returns:
        The same class, now registered as a JAX pytree (if JAX is installed).

    Note:
        - If JAX is not installed, this function does nothing.
        - Fields marked with JAX_STATIC_FIELD metadata become meta_fields.
        - Fields with init=False and not marked as Param are dropped.
        - Registration is idempotent; registering twice has no effect.
        - If the class is a DataModule, recursively registers nested dataclasses.

    Examples:
        >>> @dataclasses.dataclass
        ... class MyData:
        ...     values: list[float]
        ...     metadata: str = dataclasses.field(metadata={JAX_STATIC_FIELD: True})
        >>> register_dataclass(MyData)
        >>> # Now MyData can be used with JAX transformations

    """
    if cls in _DATACLASS_REGISTRY:
        return cls

    _DATACLASS_REGISTRY.add(cls)

    with suppress(ImportError):
        import jax

        def _is_static_field(field: dataclasses.Field) -> bool:
            if field.metadata.get(JAX_STATIC_FIELD, False):
                return True
            field_class = _find_dataclass_field(field.type)
            if field_class is not None:
                return getattr(field_class, "__is_static__", False)
            return False

        if getattr(cls, "__is_datamodule__", False):
            for field in dataclasses.fields(cls):
                field_class = _find_dataclass_field(field.type)
                if field_class is not None:
                    register_dataclass(field_class)

        drop_fields = [f.name for f in dataclasses.fields(cls) if not f.init and not _is_param_field(f.type)]
        data_fields = [f.name for f in dataclasses.fields(cls) if not _is_static_field(f) and f.name not in drop_fields]
        meta_fields = [f.name for f in dataclasses.fields(cls) if _is_static_field(f) and f.name not in drop_fields]

        try:
            jax.tree_util.register_dataclass(
                cls,
                data_fields=data_fields,
                meta_fields=meta_fields,
                drop_fields=drop_fields,
            )
        except ValueError as error:
            if str(error.args[0]).startswith("Duplicate custom dataclass"):
                pass
            else:
                raise

    return cls
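The field-partitioning step above (splitting fields into `data_fields`, `meta_fields`, and `drop_fields`) can be illustrated with a plain-dataclasses sketch. The `JAX_STATIC_FIELD` string below is a placeholder for the real metadata key, whose actual value is an assumption:

```python
import dataclasses

JAX_STATIC_FIELD = "jax_static"  # stand-in for the real metadata key


@dataclasses.dataclass
class Sample:
    values: list  # dynamic leaf -> data field
    name: str = dataclasses.field(default="", metadata={JAX_STATIC_FIELD: True})  # static -> meta field
    cache: dict = dataclasses.field(init=False, default_factory=dict)  # init=False -> dropped


fields = dataclasses.fields(Sample)
drop_fields = [f.name for f in fields if not f.init]
meta_fields = [
    f.name for f in fields if f.metadata.get(JAX_STATIC_FIELD, False) and f.name not in drop_fields
]
data_fields = [
    f.name for f in fields if f.name not in drop_fields and f.name not in meta_fields
]
print(data_fields, meta_fields, drop_fields)  # ['values'] ['name'] ['cache']
```

The real implementation additionally resolves nested dataclass fields and `Param` markers, which this sketch omits.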

formed.integrations.ml.transforms.basic

Basic data transformations for common machine learning tasks.

This module provides fundamental transform classes for handling common data types in machine learning pipelines, including labels, scalars, tensors, and metadata.

Available Transforms
  • MetadataTransform: Pass-through transform for metadata (e.g., IDs, names)
  • LabelIndexer: Map labels to integer indices with vocabulary building
  • ScalarTransform: Convert scalar values to numpy arrays
  • TensorTransform: Convert numpy arrays to batched tensors
  • VariableTensorTransform: Pad variable-size arrays into batched tensors with masks

Examples:

>>> from formed.integrations.ml import LabelIndexer, ScalarTransform
>>>
>>> # Label indexing with vocabulary building
>>> label_indexer = LabelIndexer()
>>> with label_indexer.train():
...     idx1 = label_indexer.instance("positive")  # Returns 0
...     idx2 = label_indexer.instance("negative")  # Returns 1
>>> batch = label_indexer.batch([0, 1, 0])  # np.array([0, 1, 0])
>>>
>>> # Scalar to tensor
>>> scalar_transform = ScalarTransform()
>>> values = [1.5, 2.3, 4.1]
>>> batch = scalar_transform.batch(values)  # np.array([1.5, 2.3, 4.1])

logger module-attribute

logger = getLogger(__name__)

MetadataTransform

Bases: Generic[_S, _T], BaseTransform[_S, _T, _T, Sequence[_T]]

Pass-through transform for metadata fields.

MetadataTransform does not modify data during instance transformation and simply collects values into a list during batching. This is useful for metadata like IDs, filenames, or other non-numerical information that should be preserved but not transformed into tensors.

CLASS TYPE PARAMETER DESCRIPTION
_S

Source data type before accessor

_T

Value type (same as instance and element of batch)

Examples:

>>> transform = MetadataTransform(accessor="id")
>>> instance = transform({"id": "example_001"})  # Returns "example_001"
>>> batch = transform.batch(["example_001", "example_002", "example_003"])
>>> print(batch)  # ["example_001", "example_002", "example_003"]
Note

This transform is stateless and does not require training.

accessor class-attribute instance-attribute

accessor = None

instance

instance(value)
Source code in src/formed/integrations/ml/transforms/basic.py
def instance(self, value: _T, /) -> _T:
    return value

batch

batch(batch)
Source code in src/formed/integrations/ml/transforms/basic.py
def batch(self, batch: Sequence[_T], /) -> Sequence[_T]:
    return list(batch)

train

train()

Context manager to enable training mode for stateful transformations.

In training mode, transforms can build state (e.g., vocabularies, statistics) from the training data. Hooks _on_start_training() and _on_end_training() are called at the beginning and end of the training context.

YIELDS DESCRIPTION
None

None

Examples:

>>> indexer = TokenSequenceIndexer()
>>> with indexer.train():
...     # Build vocabulary from training data
...     tokens1 = indexer.instance(["hello", "world"])
...     tokens2 = indexer.instance(["hello", "there"])
>>> # Vocabulary is now frozen, use for inference
>>> test_tokens = indexer.instance(["hello", "unknown"])
Note

Training mode is reentrant but nested calls won't trigger hooks again.

Source code in src/formed/integrations/ml/transforms/base.py
@contextmanager
def train(self) -> Iterator[None]:
    """Context manager to enable training mode for stateful transformations.

    In training mode, transforms can build state (e.g., vocabularies, statistics)
    from the training data. Hooks `_on_start_training()` and `_on_end_training()`
    are called at the beginning and end of the training context.

    Yields:
        None

    Examples:
        >>> indexer = TokenSequenceIndexer()
        >>> with indexer.train():
        ...     # Build vocabulary from training data
        ...     tokens1 = indexer.instance(["hello", "world"])
        ...     tokens2 = indexer.instance(["hello", "there"])
        >>> # Vocabulary is now frozen, use for inference
        >>> test_tokens = indexer.instance(["hello", "unknown"])

    Note:
        Training mode is reentrant but nested calls won't trigger hooks again.

    """
    original = self._training
    self._training = True
    try:
        if not original:
            self._on_start_training()
        yield
        if not original:
            self._on_end_training()
    finally:
        self._training = original
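The reentrancy guarantee in the Note can be checked with a minimal stand-in class that mirrors the context-manager logic above (the `Toy` class and its hook counter are illustrative, not part of the library):

```python
from contextlib import contextmanager


class Toy:
    def __init__(self) -> None:
        self._training = False
        self.hook_calls = 0  # counts _on_start_training invocations

    def _on_start_training(self) -> None:
        self.hook_calls += 1

    def _on_end_training(self) -> None:
        pass

    @contextmanager
    def train(self):
        original = self._training
        self._training = True
        try:
            if not original:
                self._on_start_training()
            yield
            if not original:
                self._on_end_training()
        finally:
            self._training = original


t = Toy()
with t.train():
    with t.train():  # nested entry: hooks are NOT triggered again
        pass
print(t.hook_calls)  # 1
```

Because the flag is restored in `finally`, training mode is also exited cleanly when the body raises.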

save

save(directory)

Save the transform to a directory using cloudpickle.

PARAMETER DESCRIPTION
directory

Directory path to save the transform.

TYPE: str | PathLike

Note

The transform is saved as 'transform.pkl' in the specified directory. cloudpickle is used to handle complex objects like lambdas and closures.

Source code in src/formed/integrations/ml/transforms/base.py
def save(self, directory: str | PathLike) -> None:
    """Save the transform to a directory using cloudpickle.

    Args:
        directory: Directory path to save the transform.

    Note:
        The transform is saved as 'transform.pkl' in the specified directory.
        cloudpickle is used to handle complex objects like lambdas and closures.

    """
    filepath = Path(directory) / "transform.pkl"
    with filepath.open("wb") as f:
        cloudpickle.dump(self, f)

load classmethod

load(directory)

Load a transform from a directory.

PARAMETER DESCRIPTION
directory

Directory path containing the saved transform.

TYPE: str | PathLike

RETURNS DESCRIPTION
Self

The loaded transform instance.

RAISES DESCRIPTION
TypeError

If the loaded object is not an instance of this class.

Note

Expects a 'transform.pkl' file in the specified directory.

Source code in src/formed/integrations/ml/transforms/base.py
@classmethod
def load(cls, directory: str | PathLike) -> Self:
    """Load a transform from a directory.

    Args:
        directory: Directory path containing the saved transform.

    Returns:
        The loaded transform instance.

    Raises:
        TypeError: If the loaded object is not an instance of this class.

    Note:
        Expects a 'transform.pkl' file in the specified directory.

    """
    filepath = Path(directory) / "transform.pkl"
    with filepath.open("rb") as f:
        obj = cloudpickle.load(f)
    if not isinstance(obj, cls):
        raise TypeError(f"Loaded object is not an instance of {cls.__name__}")
    return obj
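The save/load round trip is a plain pickle-to-directory pattern. A minimal stdlib sketch, with `pickle` standing in for cloudpickle and a hypothetical `state` dict in place of a real transform (the real `load()` additionally verifies the loaded object's type with `isinstance`):

```python
import pickle
import tempfile
from pathlib import Path

# Hypothetical transform state to persist
state = {"scale": 2.0}

with tempfile.TemporaryDirectory() as directory:
    filepath = Path(directory) / "transform.pkl"  # same filename convention as save()
    with filepath.open("wb") as f:
        pickle.dump(state, f)
    with filepath.open("rb") as f:
        obj = pickle.load(f)

print(obj)  # {'scale': 2.0}
```

cloudpickle is used in the library instead of stdlib pickle so that transforms holding lambdas or closures (e.g., callable accessors) can still be serialized.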

LabelIndexer

Bases: BaseTransform[_S, LabelT, int, ndarray], Generic[_S, LabelT]

Map labels to integer indices with vocabulary building and statistics tracking.

LabelIndexer maintains a bidirectional mapping between labels and integer indices. In training mode, it dynamically builds the label vocabulary and tracks label frequencies. The vocabulary can be frozen to prevent changes during inference.

CLASS TYPE PARAMETER DESCRIPTION
_S

Source data type before accessor

LabelT

Label type (must be hashable)

ATTRIBUTE DESCRIPTION
label2id

Pre-defined label-to-index mapping. If empty, built during training.

TYPE: Sequence[tuple[LabelT, int]]

freeze

If True, prevent vocabulary updates even in training mode.

TYPE: bool

Properties
  • num_labels: Total number of unique labels in vocabulary.
  • labels: List of labels sorted by their indices.
  • occurrences: Dictionary mapping labels to their occurrence counts.
  • distribution: Smoothed probability distribution over labels.

Examples:

>>> # Dynamic vocabulary building
>>> indexer = LabelIndexer()
>>> with indexer.train():
...     idx1 = indexer.instance("positive")  # 0
...     idx2 = indexer.instance("negative")  # 1
...     idx3 = indexer.instance("positive")  # 0 (already in vocab)
>>> print(indexer.labels)  # ["positive", "negative"]
>>> print(indexer.occurrences)  # {"positive": 2, "negative": 1}
>>>
>>> # Pre-defined vocabulary
>>> indexer = LabelIndexer(label2id=[("positive", 0), ("negative", 1)])
>>> idx = indexer.instance("positive")  # 0
>>>
>>> # Batching and reconstruction
>>> batch = indexer.batch([0, 1, 0])  # np.array([0, 1, 0])
>>> labels = indexer.reconstruct(batch)  # ["positive", "negative", "positive"]
Note
  • Raises KeyError if a label is not in vocabulary during inference
  • Use freeze=True to prevent accidental vocabulary updates
  • Distribution uses Laplace smoothing (add-one)

label2id class-attribute instance-attribute

label2id = field(default_factory=list)

freeze class-attribute instance-attribute

freeze = field(default=False)

num_labels property

num_labels

Get the total number of unique labels in the vocabulary.

labels property

labels

Get the list of labels sorted by their indices.

occurrences property

occurrences

Get the occurrence counts for each label seen during training.

distribution property

distribution

Get the smoothed probability distribution over labels.

Uses Laplace (add-one) smoothing to handle zero counts.

RETURNS DESCRIPTION
ndarray

Array of probabilities summing to 1.0, one per label.
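The add-one smoothing can be worked through by hand. A minimal sketch with hypothetical occurrence counts (three labels, one never seen in training):

```python
import numpy

# Hypothetical counts gathered during training: positive=8, negative=2, neutral=0
counts = numpy.array([8, 2, 0], dtype=numpy.float64)

# Laplace (add-one) smoothing: every label receives a pseudo-count of 1,
# so even unseen labels get a nonzero probability.
smoothed = counts + 1
distribution = smoothed / smoothed.sum()
print(distribution)  # [9/13, 3/13, 1/13] ~= [0.692, 0.231, 0.077]
```

Note that the unseen label ("neutral") still receives probability 1/13 rather than zero.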

accessor class-attribute instance-attribute

accessor = None

get_index

get_index(value)

Get the integer index for a label.

PARAMETER DESCRIPTION
value

The label to look up.

TYPE: LabelT

RETURNS DESCRIPTION
int

The integer index associated with the label.

RAISES DESCRIPTION
KeyError

If the label is not in the vocabulary.

Source code in src/formed/integrations/ml/transforms/basic.py
def get_index(self, value: LabelT, /) -> int:
    """Get the integer index for a label.

    Args:
        value: The label to look up.

    Returns:
        The integer index associated with the label.

    Raises:
        KeyError: If the label is not in the vocabulary.

    """
    with suppress(StopIteration):
        return next(label_id for label, label_id in self.label2id if label == value)
    raise KeyError(value)

get_value

get_value(index)

Get the label for an integer index.

PARAMETER DESCRIPTION
index

The integer index to look up.

TYPE: int

RETURNS DESCRIPTION
LabelT

The label associated with the index.

RAISES DESCRIPTION
KeyError

If the index is not in the vocabulary.

Source code in src/formed/integrations/ml/transforms/basic.py
def get_value(self, index: int, /) -> LabelT:
    """Get the label for an integer index.

    Args:
        index: The integer index to look up.

    Returns:
        The label associated with the index.

    Raises:
        KeyError: If the index is not in the vocabulary.

    """
    for label, label_id in self.label2id:
        if label_id == index:
            return label
    raise KeyError(index)

ingest

ingest(value)

Add a label to the vocabulary and update statistics.

This method is called internally during training to build the vocabulary and track label frequencies.

PARAMETER DESCRIPTION
value

The label to ingest.

TYPE: LabelT

Note

Only effective when in training mode and freeze=False. Logs a warning if called outside training mode.

Source code in src/formed/integrations/ml/transforms/basic.py
def ingest(self, value: LabelT, /) -> None:
    """Add a label to the vocabulary and update statistics.

    This method is called internally during training to build the vocabulary
    and track label frequencies.

    Args:
        value: The label to ingest.

    Note:
        Only effective when in training mode and freeze=False.
        Logs a warning if called outside training mode.

    """
    if self.freeze:
        return
    if self._training:
        try:
            self.get_index(value)
        except KeyError:
            self.label2id = list(self.label2id) + [(value, len(self.label2id))]
            self._label_counts.append((value, 0))
        for index, (label, count) in enumerate(self._label_counts):
            if label == value:
                self._label_counts[index] = (label, count + 1)
                break
    else:
        logger.warning("Ignoring ingest call when not in training mode")

instance

instance(label)
Source code in src/formed/integrations/ml/transforms/basic.py
def instance(self, label: LabelT, /) -> int:
    if self._training:
        self.ingest(label)
    return self.get_index(label)

batch

batch(batch)
Source code in src/formed/integrations/ml/transforms/basic.py
def batch(self, batch: Sequence[int], /) -> numpy.ndarray:
    return numpy.array(batch, dtype=numpy.int64)

reconstruct

reconstruct(batch)

Convert a batch of indices back to labels.

PARAMETER DESCRIPTION
batch

Array of integer indices.

TYPE: ndarray

RETURNS DESCRIPTION
list[LabelT]

List of labels corresponding to the indices.

Examples:

>>> indexer = LabelIndexer(label2id=[("cat", 0), ("dog", 1)])
>>> indices = numpy.array([0, 1, 0])
>>> labels = indexer.reconstruct(indices)
>>> print(labels)  # ["cat", "dog", "cat"]
Source code in src/formed/integrations/ml/transforms/basic.py
def reconstruct(self, batch: numpy.ndarray, /) -> list[LabelT]:
    """Convert a batch of indices back to labels.

    Args:
        batch: Array of integer indices.

    Returns:
        List of labels corresponding to the indices.

    Examples:
        >>> indexer = LabelIndexer(label2id=[("cat", 0), ("dog", 1)])
        >>> indices = numpy.array([0, 1, 0])
        >>> labels = indexer.reconstruct(indices)
        >>> print(labels)  # ["cat", "dog", "cat"]

    """
    return [self.get_value(index) for index in batch.tolist()]


ScalarTransform

Bases: Generic[_S], BaseTransform[_S, float, float, ndarray]

Transform scalar values into batched numpy arrays.

ScalarTransform is a simple pass-through transform that preserves scalar values during instance transformation and stacks them into a 1D numpy array during batching.

CLASS TYPE PARAMETER DESCRIPTION
_S

Source data type before accessor

Examples:

>>> transform = ScalarTransform(accessor="score")
>>> value = transform({"score": 0.85})  # Returns 0.85
>>> batch = transform.batch([0.85, 0.92, 0.78])
>>> print(batch)  # np.array([0.85, 0.92, 0.78], dtype=float32)
>>> print(batch.shape)  # (3,)
Note
  • Instance values remain as Python floats
  • Batch values are converted to float32 numpy arrays
  • Stateless transform, no training required

accessor class-attribute instance-attribute

accessor = None

instance

instance(value)
Source code in src/formed/integrations/ml/transforms/basic.py
def instance(self, value: float, /) -> float:
    return value

batch

batch(batch)
Source code in src/formed/integrations/ml/transforms/basic.py
def batch(self, batch: Sequence[float], /) -> numpy.ndarray:
    return numpy.array(batch, dtype=numpy.float32)


TensorTransform

Bases: Generic[_S], BaseTransform[_S, ndarray, ndarray, ndarray]

Transform numpy arrays into batched tensors.

TensorTransform preserves numpy arrays during instance transformation and stacks them along the batch dimension (axis 0) during batching. All arrays in a batch must have the same shape.

CLASS TYPE PARAMETER DESCRIPTION
_S

Source data type before accessor

Examples:

>>> import numpy as np
>>> transform = TensorTransform(accessor="features")
>>> arr = transform({"features": np.array([1.0, 2.0, 3.0])})
>>> print(arr)  # np.array([1.0, 2.0, 3.0])
>>>
>>> batch = transform.batch([
...     np.array([1.0, 2.0, 3.0]),
...     np.array([4.0, 5.0, 6.0]),
... ])
>>> print(batch.shape)  # (2, 3)
Note
  • Requires all arrays in a batch to have compatible shapes
  • Stacks along axis 0 (batch dimension)
  • Stateless transform, no training required
RAISES DESCRIPTION
ValueError

If arrays have incompatible shapes for stacking.

accessor class-attribute instance-attribute

accessor = None

instance

instance(value)
Source code in src/formed/integrations/ml/transforms/basic.py
def instance(self, value: numpy.ndarray, /) -> numpy.ndarray:
    return value

batch

batch(batch)
Source code in src/formed/integrations/ml/transforms/basic.py
def batch(self, batch: Sequence[numpy.ndarray], /) -> numpy.ndarray:
    return numpy.stack(batch, axis=0)
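The stacking behavior can be seen in a two-line sketch: `numpy.stack` with `axis=0` adds a new leading batch dimension, which is why all arrays in the batch must share a shape:

```python
import numpy

batch = [numpy.array([1.0, 2.0, 3.0]), numpy.array([4.0, 5.0, 6.0])]
stacked = numpy.stack(batch, axis=0)  # shape (batch_size, *item_shape)
print(stacked.shape)  # (2, 3)
```

For variable-size arrays, use VariableTensorTransform instead, which pads to the largest shape and returns a mask.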


VariableTensorTransform

Bases: Generic[_S], BaseTransform[_S, ndarray, ndarray, VariableTensorBatch[ndarray]]

Transform variable-size numpy arrays into padded batched tensors.

VariableTensorTransform preserves numpy arrays during instance transformation and pads them to the maximum shape in the batch during batching. It returns a VariableTensorBatch containing the padded tensor and a mask indicating valid data.

CLASS TYPE PARAMETER DESCRIPTION
_S

Source data type before accessor

Examples:

>>> import numpy as np
>>> transform = VariableTensorTransform(accessor="sequences")
>>> arr = transform({"sequences": np.array([1, 2, 3])})
>>> print(arr)  # np.array([1, 2, 3])
>>>
>>> batch = transform.batch([
...     np.array([1, 2, 3]),
...     np.array([4, 5]),
...     np.array([6, 7, 8, 9]),
... ])
>>> print(batch.tensor)
>>> # np.array([
>>> #   [1, 2, 3, 0],
>>> #   [4, 5, 0, 0],
>>> #   [6, 7, 8, 9]
>>> # ])
>>> print(batch.mask)
>>> # np.array([
>>> #   [True, True, True, False],
>>> #   [True, True, False, False],
>>> #   [True, True, True, True]
>>> # ])
Note
  • Pads arrays with zeros to match the maximum shape in the batch
  • Mask indicates which elements are valid (True) vs. padded (False)
  • Stateless transform, no training required

accessor class-attribute instance-attribute

accessor = None

instance

instance(value)
Source code in src/formed/integrations/ml/transforms/basic.py
def instance(self, value: numpy.ndarray, /) -> numpy.ndarray:
    return value

batch

batch(batch)
Source code in src/formed/integrations/ml/transforms/basic.py
def batch(self, batch: Sequence[numpy.ndarray], /) -> VariableTensorBatch[numpy.ndarray]:
    if len(batch) == 0:
        return VariableTensorBatch(
            tensor=numpy.array([], dtype=numpy.float32),
            mask=numpy.array([], dtype=numpy.bool_),
        )
    max_ndim = max(arr.ndim for arr in batch)
    max_shape = []
    for dim in range(max_ndim):
        dim_sizes = [arr.shape[dim] if dim < arr.ndim else 1 for arr in batch]
        max_shape.append(max(dim_sizes))

    tensor = numpy.zeros((len(batch), *max_shape), dtype=batch[0].dtype)
    mask = numpy.zeros((len(batch), *max_shape), dtype=numpy.bool_)

    for i, arr in enumerate(batch):
        slices = tuple(slice(0, dim_size) for dim_size in arr.shape)
        tensor[i][slices] = arr
        mask[i][slices] = True

    return VariableTensorBatch[numpy.ndarray](tensor=tensor, mask=mask)
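The zero-pad-and-mask scheme above can be exercised standalone. A minimal 1-D sketch of the same idea (pad to the longest array, mark valid positions in a boolean mask):

```python
import numpy

batch = [numpy.array([1, 2, 3]), numpy.array([4, 5])]
max_len = max(arr.shape[0] for arr in batch)

tensor = numpy.zeros((len(batch), max_len), dtype=batch[0].dtype)
mask = numpy.zeros((len(batch), max_len), dtype=numpy.bool_)
for i, arr in enumerate(batch):
    tensor[i, : arr.shape[0]] = arr  # copy valid data
    mask[i, : arr.shape[0]] = True   # mark valid positions

print(tensor.tolist())  # [[1, 2, 3], [4, 5, 0]]
print(mask.tolist())    # [[True, True, True], [True, True, False]]
```

The library's `batch` generalizes this to arbitrary dimensionality by taking the per-dimension maximum across the batch.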

train

train()

Context manager to enable training mode for stateful transformations.

In training mode, transforms can build state (e.g., vocabularies, statistics) from the training data. Hooks _on_start_training() and _on_end_training() are called at the beginning and end of the training context.

YIELDS DESCRIPTION
None

None

Examples:

>>> indexer = TokenSequenceIndexer()
>>> with indexer.train():
...     # Build vocabulary from training data
...     tokens1 = indexer.instance(["hello", "world"])
...     tokens2 = indexer.instance(["hello", "there"])
>>> # Vocabulary is now frozen, use for inference
>>> test_tokens = indexer.instance(["hello", "unknown"])
Note

Training mode is reentrant but nested calls won't trigger hooks again.

Source code in src/formed/integrations/ml/transforms/base.py
@contextmanager
def train(self) -> Iterator[None]:
    """Context manager to enable training mode for stateful transformations.

    In training mode, transforms can build state (e.g., vocabularies, statistics)
    from the training data. Hooks `_on_start_training()` and `_on_end_training()`
    are called at the beginning and end of the training context.

    Yields:
        None

    Examples:
        >>> indexer = TokenSequenceIndexer()
        >>> with indexer.train():
        ...     # Build vocabulary from training data
        ...     tokens1 = indexer.instance(["hello", "world"])
        ...     tokens2 = indexer.instance(["hello", "there"])
        >>> # Vocabulary is now frozen, use for inference
        >>> test_tokens = indexer.instance(["hello", "unknown"])

    Note:
        Training mode is reentrant but nested calls won't trigger hooks again.

    """
    original = self._training
    self._training = True
    try:
        if not original:
            self._on_start_training()
        yield
        if not original:
            self._on_end_training()
    finally:
        self._training = original

save

save(directory)

Save the transform to a directory using cloudpickle.

PARAMETER DESCRIPTION
directory

Directory path to save the transform.

TYPE: str | PathLike

Note

The transform is saved as 'transform.pkl' in the specified directory. cloudpickle is used to handle complex objects like lambdas and closures.

Source code in src/formed/integrations/ml/transforms/base.py
def save(self, directory: str | PathLike) -> None:
    """Save the transform to a directory using cloudpickle.

    Args:
        directory: Directory path to save the transform.

    Note:
        The transform is saved as 'transform.pkl' in the specified directory.
        cloudpickle is used to handle complex objects like lambdas and closures.

    """
    filepath = Path(directory) / "transform.pkl"
    with filepath.open("wb") as f:
        cloudpickle.dump(self, f)

load classmethod

load(directory)

Load a transform from a directory.

PARAMETER DESCRIPTION
directory

Directory path containing the saved transform.

TYPE: str | PathLike

RETURNS DESCRIPTION
Self

The loaded transform instance.

RAISES DESCRIPTION
TypeError

If the loaded object is not an instance of this class.

Note

Expects a 'transform.pkl' file in the specified directory.

Source code in src/formed/integrations/ml/transforms/base.py
@classmethod
def load(cls, directory: str | PathLike) -> Self:
    """Load a transform from a directory.

    Args:
        directory: Directory path containing the saved transform.

    Returns:
        The loaded transform instance.

    Raises:
        TypeError: If the loaded object is not an instance of this class.

    Note:
        Expects a 'transform.pkl' file in the specified directory.

    """
    filepath = Path(directory) / "transform.pkl"
    with filepath.open("rb") as f:
        obj = cloudpickle.load(f)
    if not isinstance(obj, cls):
        raise TypeError(f"Loaded object is not an instance of {cls.__name__}")
    return obj

TensorSequenceTransform

Bases: Generic[_S], BaseTransform[_S, Sequence[ndarray], ndarray, VariableTensorBatch[ndarray]]

Transform sequences of numpy arrays into padded batched tensors.

TensorSequenceTransform handles sequences of arrays (e.g., token-level embeddings) by stacking them into a 2D array during instance transformation, then padding across the batch dimension during batching.

This is useful for token-level features where each token has its own vector, and different instances may have different numbers of tokens.

CLASS TYPE PARAMETER DESCRIPTION
_S

Source data type before accessor

Examples:

>>> import numpy as np
>>> transform = TensorSequenceTransform(accessor="token_vectors")
>>>
>>> # Each instance has a sequence of token vectors
>>> data1 = {"token_vectors": [np.array([1.0, 2.0]), np.array([3.0, 4.0])]}
>>> data2 = {"token_vectors": [np.array([5.0, 6.0])]}
>>>
>>> instance1 = transform(data1)  # Shape: (2, 2) - 2 tokens, 2 dims
>>> instance2 = transform(data2)  # Shape: (1, 2) - 1 token, 2 dims
>>>
>>> batch = transform.batch([instance1, instance2])
>>> print(batch.tensor.shape)  # (2, 2, 2) - batch_size, max_tokens, dims
>>> print(batch.mask.shape)    # (2, 2, 2)
>>> print(batch.mask[0])  # [[True, True], [True, True]]
>>> print(batch.mask[1])  # [[True, True], [False, False]]
Note
  • Converts sequence of arrays to 2D array during instance transformation
  • Pads to max token count during batching
  • Mask indicates valid tokens vs. padding
  • Empty sequences result in empty arrays
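The stack-then-pad flow above can be reproduced in a few lines of plain numpy (an illustrative sketch, not library code):

```python
import numpy as np

# Stack per-token vectors into (num_tokens, dim), then pad across the
# batch to the longest token count, mirroring the behavior described above.
seq1 = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
seq2 = [np.array([5.0, 6.0])]
inst1 = np.stack(seq1)  # shape (2, 2)
inst2 = np.stack(seq2)  # shape (1, 2)

max_tokens = max(inst1.shape[0], inst2.shape[0])
dim = inst1.shape[1]
tensor = np.zeros((2, max_tokens, dim), dtype=inst1.dtype)
mask = np.zeros((2, max_tokens, dim), dtype=np.bool_)
for i, inst in enumerate((inst1, inst2)):
    tensor[i, : inst.shape[0]] = inst
    mask[i, : inst.shape[0]] = True
# tensor.shape == (2, 2, 2); the second instance's second token row is padding
```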

accessor class-attribute instance-attribute

accessor = None

instance

instance(value)
Source code in src/formed/integrations/ml/transforms/basic.py
def instance(self, value: Sequence[numpy.ndarray], /) -> numpy.ndarray:
    if len(value) == 0:
        return numpy.array([], dtype=numpy.float32)
    # Stack sequence of arrays into 2D array (num_tokens, embedding_dim)
    return numpy.stack([numpy.asarray(arr) for arr in value], axis=0)

batch

batch(batch)
Source code in src/formed/integrations/ml/transforms/basic.py
def batch(self, batch: Sequence[numpy.ndarray], /) -> VariableTensorBatch[numpy.ndarray]:
    if len(batch) == 0:
        return VariableTensorBatch(
            tensor=numpy.array([], dtype=numpy.float32),
            mask=numpy.array([], dtype=numpy.bool_),
        )
    max_ndim = max(arr.ndim for arr in batch)
    max_shape = []
    for dim in range(max_ndim):
        dim_sizes = [arr.shape[dim] if dim < arr.ndim else 1 for arr in batch]
        max_shape.append(max(dim_sizes))

    tensor = numpy.zeros((len(batch), *max_shape), dtype=batch[0].dtype)
    mask = numpy.zeros((len(batch), *max_shape), dtype=numpy.bool_)

    for i, arr in enumerate(batch):
        slices = tuple(slice(0, dim_size) for dim_size in arr.shape)
        tensor[i][slices] = arr
        mask[i][slices] = True

    return VariableTensorBatch[numpy.ndarray](tensor=tensor, mask=mask)

train

train()

Context manager to enable training mode for stateful transformations.

In training mode, transforms can build state (e.g., vocabularies, statistics) from the training data. Hooks _on_start_training() and _on_end_training() are called at the beginning and end of the training context.

YIELDS DESCRIPTION
None

None

Examples:

>>> indexer = TokenSequenceIndexer()
>>> with indexer.train():
...     # Build vocabulary from training data
...     tokens1 = indexer.instance(["hello", "world"])
...     tokens2 = indexer.instance(["hello", "there"])
>>> # Vocabulary is now frozen, use for inference
>>> test_tokens = indexer.instance(["hello", "unknown"])
Note

Training mode is reentrant but nested calls won't trigger hooks again.

Source code in src/formed/integrations/ml/transforms/base.py
@contextmanager
def train(self) -> Iterator[None]:
    """Context manager to enable training mode for stateful transformations.

    In training mode, transforms can build state (e.g., vocabularies, statistics)
    from the training data. Hooks `_on_start_training()` and `_on_end_training()`
    are called at the beginning and end of the training context.

    Yields:
        None

    Examples:
        >>> indexer = TokenSequenceIndexer()
        >>> with indexer.train():
        ...     # Build vocabulary from training data
        ...     tokens1 = indexer.instance(["hello", "world"])
        ...     tokens2 = indexer.instance(["hello", "there"])
        >>> # Vocabulary is now frozen, use for inference
        >>> test_tokens = indexer.instance(["hello", "unknown"])

    Note:
        Training mode is reentrant but nested calls won't trigger hooks again.

    """
    original = self._training
    self._training = True
    try:
        if not original:
            self._on_start_training()
        yield
        if not original:
            self._on_end_training()
    finally:
        self._training = original

save

save(directory)

Save the transform to a directory using cloudpickle.

PARAMETER DESCRIPTION
directory

Directory path to save the transform.

TYPE: str | PathLike

Note

The transform is saved as 'transform.pkl' in the specified directory. cloudpickle is used to handle complex objects like lambdas and closures.

Source code in src/formed/integrations/ml/transforms/base.py
def save(self, directory: str | PathLike) -> None:
    """Save the transform to a directory using cloudpickle.

    Args:
        directory: Directory path to save the transform.

    Note:
        The transform is saved as 'transform.pkl' in the specified directory.
        cloudpickle is used to handle complex objects like lambdas and closures.

    """
    filepath = Path(directory) / "transform.pkl"
    with filepath.open("wb") as f:
        cloudpickle.dump(self, f)

load classmethod

load(directory)

Load a transform from a directory.

PARAMETER DESCRIPTION
directory

Directory path containing the saved transform.

TYPE: str | PathLike

RETURNS DESCRIPTION
Self

The loaded transform instance.

RAISES DESCRIPTION
TypeError

If the loaded object is not an instance of this class.

Note

Expects a 'transform.pkl' file in the specified directory.

Source code in src/formed/integrations/ml/transforms/base.py
@classmethod
def load(cls, directory: str | PathLike) -> Self:
    """Load a transform from a directory.

    Args:
        directory: Directory path containing the saved transform.

    Returns:
        The loaded transform instance.

    Raises:
        TypeError: If the loaded object is not an instance of this class.

    Note:
        Expects a 'transform.pkl' file in the specified directory.

    """
    filepath = Path(directory) / "transform.pkl"
    with filepath.open("rb") as f:
        obj = cloudpickle.load(f)
    if not isinstance(obj, cls):
        raise TypeError(f"Loaded object is not an instance of {cls.__name__}")
    return obj

formed.integrations.ml.transforms.nlp

NLP-specific data transformations for text processing.

This module provides transformations for natural language processing tasks, including tokenization, vocabulary building, and sequence indexing with special tokens (PAD, UNK, BOS, EOS).

Available Transforms
  • TokenSequenceIndexer: Convert token sequences to integer indices with vocab building
  • TokenCharactersIndexer: Character-level indexing for tokens
  • Tokenizer: Complete tokenization pipeline with surfaces, postags, and characters
Features
  • Dynamic vocabulary building with min_df/max_df filtering
  • Document frequency tracking
  • Special token handling (PAD, UNK, BOS, EOS)
  • Automatic padding and masking
  • Reconstruction support (indices -> tokens)

Examples:

>>> from formed.integrations.ml import Tokenizer, TokenSequenceIndexer
>>>
>>> # Simple tokenization
>>> tokenizer = Tokenizer(surfaces=TokenSequenceIndexer(
...     unk_token="<UNK>", pad_token="<PAD>"
... ))
>>>
>>> with tokenizer.train():
...     instance1 = tokenizer.instance("Hello world!")
...     instance2 = tokenizer.instance("How are you?")
...     instance3 = tokenizer.instance("Goodbye!")
>>> batch = tokenizer.batch([instance1, instance2, instance3])
>>> print(batch.surfaces.ids.shape)  # (3, max_length)
>>> print(batch.surfaces.mask.shape)  # (3, max_length)
>>>
>>> # Reconstruct tokens from indices
>>> tokens = tokenizer.surfaces.reconstruct(batch.surfaces)

logger module-attribute

logger = getLogger(__name__)

TokenSequenceIndexer

Bases: BaseTransform[_S, Sequence[str], Sequence[str], IDSequenceBatch], Generic[_S]

Convert token sequences to integer indices with vocabulary building and filtering.

TokenSequenceIndexer builds and maintains a vocabulary, converting tokens to indices. It supports vocabulary filtering by document frequency (min_df/max_df), vocabulary size limits, and special tokens (PAD, UNK, BOS, EOS). During batching, sequences are padded to the same length and a mask is generated.

CLASS TYPE PARAMETER DESCRIPTION
_S

Source data type before accessor

ATTRIBUTE DESCRIPTION
vocab

Pre-defined or built vocabulary mapping tokens to indices.

TYPE: Mapping[str, int]

pad_token

Padding token (required).

TYPE: str

unk_token

Unknown token for out-of-vocabulary tokens (optional).

TYPE: str | None

bos_token

Beginning-of-sequence token (optional).

TYPE: str | None

eos_token

End-of-sequence token (optional).

TYPE: str | None

min_df

Minimum document frequency (int or fraction) to include token.

TYPE: int | float

max_df

Maximum document frequency (int or fraction) to include token.

TYPE: int | float

max_vocab_size

Maximum vocabulary size (excluding special tokens).

TYPE: int | None

freeze

If True, prevent vocabulary updates.

TYPE: bool

Examples:

>>> # Build vocabulary with filtering
>>> indexer = TokenSequenceIndexer(
...     unk_token="<UNK>",
...     min_df=2,  # Tokens must appear in at least 2 documents
...     max_vocab_size=10000
... )
>>>
>>> with indexer.train():
...     tokens1 = indexer.instance(["hello", "world"])
...     tokens2 = indexer.instance(["hello", "there"])
>>> # Vocabulary: {"<PAD>": 0, "<UNK>": 1, "hello": 2, "world": 3, "there": 4}
>>>
>>> # Create batch with padding
>>> batch = indexer.batch([
...     ["hello", "world"],
...     ["hello", "there", "friend"]
... ])
>>> print(batch.ids.shape)  # (2, 3) - padded to max length
>>> print(batch.mask.shape)  # (2, 3) - True for real tokens
>>>
>>> # Reconstruct original tokens
>>> tokens = indexer.reconstruct(batch)
>>> print(tokens)  # [["hello", "world"], ["hello", "there", "friend"]]
Note
  • Special tokens are always added first and never filtered
  • min_df/max_df require unk_token to handle filtered tokens
  • Document frequency counts unique tokens per document
  • BOS/EOS tokens are added during batching if specified
  • Reconstruction removes special tokens and padding
RAISES DESCRIPTION
ValueError

If configuration is invalid (e.g., min_df>1 without unk_token).
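The min_df/max_df filtering semantics can be sketched standalone (illustrative names; an int threshold is an absolute document count, a float is a fraction of documents, as described above):

```python
from collections import Counter

docs = [["hello", "world"], ["hello", "there"], ["hi"]]
# Document frequency: count each token once per document.
df = Counter(token for doc in docs for token in set(doc))
n_docs = len(docs)
min_df, max_df = 2, 1.0  # int = absolute count, float = fraction of docs

def resolve(threshold):
    # Hypothetical helper: convert a fractional threshold to a count.
    return threshold if isinstance(threshold, int) else int(threshold * n_docs)

kept = [t for t, c in df.items() if resolve(min_df) <= c <= resolve(max_df)]
# Special tokens are added first and never filtered.
vocab = {"<PAD>": 0, "<UNK>": 1}
for token in sorted(kept):
    vocab[token] = len(vocab)
# Only "hello" (df=2) survives min_df=2; the rest map to <UNK>
```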

vocab class-attribute instance-attribute

vocab = field(default_factory=dict)

pad_token class-attribute instance-attribute

pad_token = '<PAD>'

unk_token class-attribute instance-attribute

unk_token = None

bos_token class-attribute instance-attribute

bos_token = None

eos_token class-attribute instance-attribute

eos_token = None

min_df class-attribute instance-attribute

min_df = 1

max_df class-attribute instance-attribute

max_df = 1.0

max_vocab_size class-attribute instance-attribute

max_vocab_size = None

freeze class-attribute instance-attribute

freeze = False

pad_index property

pad_index

Index of the padding token.

unk_index property

unk_index

Index of the unknown token, or None if not set.

bos_index property

bos_index

Index of the beginning-of-sequence token, or None if not set.

eos_index property

eos_index

Index of the end-of-sequence token, or None if not set.

vocab_size property

vocab_size

Total number of tokens in the vocabulary.

accessor class-attribute instance-attribute

accessor = None

get_index

get_index(value)

Get the index of a token, using unk_token if not found.

Source code in src/formed/integrations/ml/transforms/nlp.py
def get_index(self, value: str, /) -> int:
    """Get the index of a token, using unk_token if not found."""
    if value in self.vocab:
        return self.vocab[value]
    if self.unk_token is not None:
        return self.vocab[self.unk_token]
    raise KeyError(value)

get_value

get_value(index)

Get the token corresponding to an index.

Source code in src/formed/integrations/ml/transforms/nlp.py
def get_value(self, index: int, /) -> str:
    """Get the token corresponding to an index."""
    if index in self._inverted_vocab:
        return self._inverted_vocab[index]
    raise KeyError(index)

ingest

ingest(values)

Ingest a sequence of tokens to update counts for vocabulary building.

Source code in src/formed/integrations/ml/transforms/nlp.py
def ingest(self, values: Sequence[str], /) -> None:
    """Ingest a sequence of tokens to update counts for vocabulary building."""
    if self.freeze:
        return
    if self._training:
        for token in values:
            self._token_counts[token] = self._token_counts.get(token, 0) + 1
        self._document_count += 1
        for token in set(values):
            self._document_frequencies[token] = self._document_frequencies.get(token, 0) + 1
    else:
        logger.warning("Ignoring ingest call when not in training mode")

instance

instance(tokens)
Source code in src/formed/integrations/ml/transforms/nlp.py
def instance(self, tokens: Sequence[str], /) -> Sequence[str]:
    if self._training:
        self.ingest(tokens)
    return tokens

batch

batch(batch)
Source code in src/formed/integrations/ml/transforms/nlp.py
def batch(self, batch: Sequence[Sequence[str]], /) -> IDSequenceBatch:
    batch_size = len(batch)
    max_length = max(len(tokens) for tokens in batch)
    if self.bos_token is not None:
        max_length += 1
    if self.eos_token is not None:
        max_length += 1
    ids = numpy.full((batch_size, max_length), self.pad_index, dtype=numpy.int64)
    mask = numpy.zeros((batch_size, max_length), dtype=numpy.bool_)
    for i, tokens in enumerate(batch):
        indices = [self.get_index(token) for token in tokens]
        if self.bos_token is not None:
            indices = [self.vocab[self.bos_token]] + indices
        if self.eos_token is not None:
            indices = indices + [self.vocab[self.eos_token]]
        length = len(indices)
        ids[i, :length] = indices
        mask[i, :length] = 1
    return IDSequenceBatch(ids=ids, mask=mask)
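A small worked example of this padding scheme with a BOS token, written standalone rather than through the class (the vocabulary here is illustrative):

```python
import numpy as np

vocab = {"<PAD>": 0, "<BOS>": 1, "a": 2, "b": 3, "c": 4}
batch = [["a", "b"], ["a", "b", "c"]]
max_length = max(len(s) for s in batch) + 1  # +1 for the BOS token
ids = np.full((len(batch), max_length), vocab["<PAD>"], dtype=np.int64)
mask = np.zeros((len(batch), max_length), dtype=np.bool_)
for i, tokens in enumerate(batch):
    indices = [vocab["<BOS>"]] + [vocab[t] for t in tokens]
    ids[i, : len(indices)] = indices
    mask[i, : len(indices)] = True
# ids rows: [1, 2, 3, 0] and [1, 2, 3, 4]; mask flags the trailing pad
```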

reconstruct

reconstruct(batch)

Reconstruct token sequences from a batch of indices.

Source code in src/formed/integrations/ml/transforms/nlp.py
def reconstruct(self, batch: IDSequenceBatch, /) -> list[Sequence[str]]:
    """Reconstruct token sequences from a batch of indices."""
    sequences = []
    for i in range(batch.ids.shape[0]):
        length = int(batch.mask[i].sum())
        indices = batch.ids[i, :length].tolist()
        tokens = [self.get_value(index) for index in indices]
        if tokens and tokens[0] == self.bos_token:
            tokens = tokens[1:]
        if tokens and tokens[-1] == self.eos_token:
            tokens = tokens[:-1]
        tokens = [token for token in tokens if token != self.pad_token]
        sequences.append(tokens)
    return sequences
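The inverse mapping above amounts to slicing each row by its mask length and looking indices up in an inverted vocabulary; a minimal standalone sketch (illustrative vocabulary and names):

```python
import numpy as np

vocab = {"<PAD>": 0, "a": 1, "b": 2, "c": 3}
inv = {index: token for token, index in vocab.items()}
ids = np.array([[1, 2, 0], [1, 2, 3]])
mask = ids != vocab["<PAD>"]
# Keep only the valid prefix of each row, then map indices back to tokens.
sequences = [
    [inv[i] for i in row[: int(m.sum())]]
    for row, m in zip(ids.tolist(), mask)
]
# sequences recovers [["a", "b"], ["a", "b", "c"]]
```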

train

train()

Context manager to enable training mode for stateful transformations.

In training mode, transforms can build state (e.g., vocabularies, statistics) from the training data. Hooks _on_start_training() and _on_end_training() are called at the beginning and end of the training context.

YIELDS DESCRIPTION
None

None

Examples:

>>> indexer = TokenSequenceIndexer()
>>> with indexer.train():
...     # Build vocabulary from training data
...     tokens1 = indexer.instance(["hello", "world"])
...     tokens2 = indexer.instance(["hello", "there"])
>>> # Vocabulary is now frozen, use for inference
>>> test_tokens = indexer.instance(["hello", "unknown"])
Note

Training mode is reentrant but nested calls won't trigger hooks again.

Source code in src/formed/integrations/ml/transforms/base.py
@contextmanager
def train(self) -> Iterator[None]:
    """Context manager to enable training mode for stateful transformations.

    In training mode, transforms can build state (e.g., vocabularies, statistics)
    from the training data. Hooks `_on_start_training()` and `_on_end_training()`
    are called at the beginning and end of the training context.

    Yields:
        None

    Examples:
        >>> indexer = TokenSequenceIndexer()
        >>> with indexer.train():
        ...     # Build vocabulary from training data
        ...     tokens1 = indexer.instance(["hello", "world"])
        ...     tokens2 = indexer.instance(["hello", "there"])
        >>> # Vocabulary is now frozen, use for inference
        >>> test_tokens = indexer.instance(["hello", "unknown"])

    Note:
        Training mode is reentrant but nested calls won't trigger hooks again.

    """
    original = self._training
    self._training = True
    try:
        if not original:
            self._on_start_training()
        yield
        if not original:
            self._on_end_training()
    finally:
        self._training = original

save

save(directory)

Save the transform to a directory using cloudpickle.

PARAMETER DESCRIPTION
directory

Directory path to save the transform.

TYPE: str | PathLike

Note

The transform is saved as 'transform.pkl' in the specified directory. cloudpickle is used to handle complex objects like lambdas and closures.

Source code in src/formed/integrations/ml/transforms/base.py
def save(self, directory: str | PathLike) -> None:
    """Save the transform to a directory using cloudpickle.

    Args:
        directory: Directory path to save the transform.

    Note:
        The transform is saved as 'transform.pkl' in the specified directory.
        cloudpickle is used to handle complex objects like lambdas and closures.

    """
    filepath = Path(directory) / "transform.pkl"
    with filepath.open("wb") as f:
        cloudpickle.dump(self, f)

load classmethod

load(directory)

Load a transform from a directory.

PARAMETER DESCRIPTION
directory

Directory path containing the saved transform.

TYPE: str | PathLike

RETURNS DESCRIPTION
Self

The loaded transform instance.

RAISES DESCRIPTION
TypeError

If the loaded object is not an instance of this class.

Note

Expects a 'transform.pkl' file in the specified directory.

Source code in src/formed/integrations/ml/transforms/base.py
@classmethod
def load(cls, directory: str | PathLike) -> Self:
    """Load a transform from a directory.

    Args:
        directory: Directory path containing the saved transform.

    Returns:
        The loaded transform instance.

    Raises:
        TypeError: If the loaded object is not an instance of this class.

    Note:
        Expects a 'transform.pkl' file in the specified directory.

    """
    filepath = Path(directory) / "transform.pkl"
    with filepath.open("rb") as f:
        obj = cloudpickle.load(f)
    if not isinstance(obj, cls):
        raise TypeError(f"Loaded object is not an instance of {cls.__name__}")
    return obj

TokenCharactersIndexer

Bases: TokenSequenceIndexer[_S], Generic[_S]

Character-level indexing for token sequences.

TokenCharactersIndexer extends TokenSequenceIndexer to index individual characters within tokens. This is useful for character-level models or handling rare words. The batch output is a 3D tensor: (batch_size, num_tokens, max_characters).

CLASS TYPE PARAMETER DESCRIPTION
_S

Source data type before accessor

ATTRIBUTE DESCRIPTION
min_characters

Minimum character length per token (for padding).

TYPE: int

Note

Inherits all attributes from TokenSequenceIndexer.

Examples:

>>> indexer = TokenCharactersIndexer(
...     unk_token="<UNK>",
...     bos_token="<BOS>",
...     eos_token="<EOS>"
... )
>>>
>>> with indexer.train():
...     tokens = indexer.instance(["hello", "world"])
>>>
>>> batch = indexer.batch([["hello", "world"], ["hi"]])
>>> print(batch.ids.shape)  # (2, 2, 7) - batch x tokens x chars
>>> print(batch.mask.shape)  # (2, 2, 7)
Note
  • Each token is converted to a sequence of character indices
  • BOS/EOS are added per token, not per sequence
  • Vocabulary contains individual characters, not tokens
  • Useful for morphologically rich languages or rare word handling
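The character-level indexing described above (without BOS/EOS) can be illustrated standalone with a toy character vocabulary:

```python
import numpy as np

# Illustrative character vocabulary; the real one is built during training.
char_vocab = {"<PAD>": 0, "h": 1, "i": 2, "e": 3, "l": 4, "o": 5}
tokens = ["hello", "hi"]
max_chars = max(len(t) for t in tokens)
ids = np.zeros((len(tokens), max_chars), dtype=np.int64)
for j, token in enumerate(tokens):
    ids[j, : len(token)] = [char_vocab[c] for c in token]
# "hello" -> [1, 3, 4, 4, 5]; "hi" -> [1, 2] padded with zeros
```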

min_characters class-attribute instance-attribute

min_characters = 1

accessor class-attribute instance-attribute

accessor = None

vocab class-attribute instance-attribute

vocab = field(default_factory=dict)

pad_token class-attribute instance-attribute

pad_token = '<PAD>'

unk_token class-attribute instance-attribute

unk_token = None

bos_token class-attribute instance-attribute

bos_token = None

eos_token class-attribute instance-attribute

eos_token = None

min_df class-attribute instance-attribute

min_df = 1

max_df class-attribute instance-attribute

max_df = 1.0

max_vocab_size class-attribute instance-attribute

max_vocab_size = None

freeze class-attribute instance-attribute

freeze = False

pad_index property

pad_index

Index of the padding token.

unk_index property

unk_index

Index of the unknown token, or None if not set.

bos_index property

bos_index

Index of the beginning-of-sequence token, or None if not set.

eos_index property

eos_index

Index of the end-of-sequence token, or None if not set.

vocab_size property

vocab_size

Total number of tokens in the vocabulary.

ingest

ingest(values)
Source code in src/formed/integrations/ml/transforms/nlp.py
def ingest(self, values: Sequence[str], /) -> None:
    super().ingest("".join(values))

batch

batch(batch)
Source code in src/formed/integrations/ml/transforms/nlp.py
def batch(self, batch: Sequence[Sequence[str]], /) -> IDSequenceBatch:
    batch_size = len(batch)
    max_tokens = max(len(tokens) for tokens in batch)
    max_characters = max(self.min_characters, max(len(token) for tokens in batch for token in tokens))
    ids = numpy.full((batch_size, max_tokens, max_characters), self.pad_index, dtype=numpy.int64)
    mask = numpy.zeros((batch_size, max_tokens, max_characters), dtype=numpy.bool_)
    for i, tokens in enumerate(batch):
        for j, token in enumerate(tokens):
            indices = [self.get_index(char) for char in token]
            if self.bos_token is not None:
                indices = [self.vocab[self.bos_token]] + indices
            if self.eos_token is not None:
                indices = indices + [self.vocab[self.eos_token]]
            length = len(indices)
            ids[i, j, :length] = indices
            mask[i, j, :length] = 1
    return IDSequenceBatch(ids=ids, mask=mask)

reconstruct

reconstruct(batch)
Source code in src/formed/integrations/ml/transforms/nlp.py
def reconstruct(self, batch: IDSequenceBatch, /) -> list[Sequence[str]]:
    sequences = []
    for i in range(batch.ids.shape[0]):
        token_indices = batch.ids[i]
        token_mask = batch.mask[i]
        tokens = []
        for j in range(token_indices.shape[0]):
            if not token_mask[j].any():
                break
            length = int(token_mask[j].sum())
            char_indices = token_indices[j, :length].tolist()
            chars = [self.get_value(index) for index in char_indices]
            tokens.append("".join(chars))
        if tokens and tokens[0] == self.bos_token:
            tokens = tokens[1:]
        if tokens and tokens[-1] == self.eos_token:
            tokens = tokens[:-1]
        tokens = [token for token in tokens if token != self.pad_token]
        sequences.append(tokens)
    return sequences

instance

instance(tokens)
Source code in src/formed/integrations/ml/transforms/nlp.py
def instance(self, tokens: Sequence[str], /) -> Sequence[str]:
    if self._training:
        self.ingest(tokens)
    return tokens

train

train()

Context manager to enable training mode for stateful transformations.

In training mode, transforms can build state (e.g., vocabularies, statistics) from the training data. Hooks _on_start_training() and _on_end_training() are called at the beginning and end of the training context.

YIELDS DESCRIPTION
None

None

Examples:

>>> indexer = TokenSequenceIndexer()
>>> with indexer.train():
...     # Build vocabulary from training data
...     tokens1 = indexer.instance(["hello", "world"])
...     tokens2 = indexer.instance(["hello", "there"])
>>> # Vocabulary is now frozen, use for inference
>>> test_tokens = indexer.instance(["hello", "unknown"])
Note

Training mode is reentrant but nested calls won't trigger hooks again.

Source code in src/formed/integrations/ml/transforms/base.py
@contextmanager
def train(self) -> Iterator[None]:
    """Context manager to enable training mode for stateful transformations.

    In training mode, transforms can build state (e.g., vocabularies, statistics)
    from the training data. Hooks `_on_start_training()` and `_on_end_training()`
    are called at the beginning and end of the training context.

    Yields:
        None

    Examples:
        >>> indexer = TokenSequenceIndexer()
        >>> with indexer.train():
        ...     # Build vocabulary from training data
        ...     tokens1 = indexer.instance(["hello", "world"])
        ...     tokens2 = indexer.instance(["hello", "there"])
        >>> # Vocabulary is now frozen, use for inference
        >>> test_tokens = indexer.instance(["hello", "unknown"])

    Note:
        Training mode is reentrant but nested calls won't trigger hooks again.

    """
    original = self._training
    self._training = True
    try:
        if not original:
            self._on_start_training()
        yield
        if not original:
            self._on_end_training()
    finally:
        self._training = original

save

save(directory)

Save the transform to a directory using cloudpickle.

PARAMETER DESCRIPTION
directory

Directory path to save the transform.

TYPE: str | PathLike

Note

The transform is saved as 'transform.pkl' in the specified directory. cloudpickle is used to handle complex objects like lambdas and closures.

Source code in src/formed/integrations/ml/transforms/base.py
def save(self, directory: str | PathLike) -> None:
    """Save the transform to a directory using cloudpickle.

    Args:
        directory: Directory path to save the transform.

    Note:
        The transform is saved as 'transform.pkl' in the specified directory.
        cloudpickle is used to handle complex objects like lambdas and closures.

    """
    filepath = Path(directory) / "transform.pkl"
    with filepath.open("wb") as f:
        cloudpickle.dump(self, f)

load classmethod

load(directory)

Load a transform from a directory.

PARAMETER DESCRIPTION
directory

Directory path containing the saved transform.

TYPE: str | PathLike

RETURNS DESCRIPTION
Self

The loaded transform instance.

RAISES DESCRIPTION
TypeError

If the loaded object is not an instance of this class.

Note

Expects a 'transform.pkl' file in the specified directory.

Source code in src/formed/integrations/ml/transforms/base.py
@classmethod
def load(cls, directory: str | PathLike) -> Self:
    """Load a transform from a directory.

    Args:
        directory: Directory path containing the saved transform.

    Returns:
        The loaded transform instance.

    Raises:
        TypeError: If the loaded object is not an instance of this class.

    Note:
        Expects a 'transform.pkl' file in the specified directory.

    """
    filepath = Path(directory) / "transform.pkl"
    with filepath.open("rb") as f:
        obj = cloudpickle.load(f)
    if not isinstance(obj, cls):
        raise TypeError(f"Loaded object is not an instance of {cls.__name__}")
    return obj
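The save/load convention can be exercised with a minimal round trip. This sketch substitutes the stdlib `pickle` for `cloudpickle` (sufficient for plain classes; cloudpickle is only needed for lambdas and closures), and the `Toy` transform is hypothetical:

```python
import pickle
import tempfile
from pathlib import Path


class Toy:
    """Hypothetical transform with picklable state."""

    def __init__(self, vocab: dict[str, int]) -> None:
        self.vocab = vocab

    def save(self, directory: str) -> None:
        # Mirrors the convention above: a single 'transform.pkl' file.
        with (Path(directory) / "transform.pkl").open("wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, directory: str) -> "Toy":
        with (Path(directory) / "transform.pkl").open("rb") as f:
            obj = pickle.load(f)
        if not isinstance(obj, cls):
            raise TypeError(f"Loaded object is not an instance of {cls.__name__}")
        return obj


with tempfile.TemporaryDirectory() as d:
    Toy({"hello": 0, "world": 1}).save(d)
    restored = Toy.load(d)

print(restored.vocab)  # {'hello': 0, 'world': 1}
```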

get_index

get_index(value)

Get the index of a token, using unk_token if not found.

Source code in src/formed/integrations/ml/transforms/nlp.py
def get_index(self, value: str, /) -> int:
    """Get the index of a token, using unk_token if not found."""
    if value in self.vocab:
        return self.vocab[value]
    if self.unk_token is not None:
        return self.vocab[self.unk_token]
    raise KeyError(value)

get_value

get_value(index)

Get the token corresponding to an index.

Source code in src/formed/integrations/ml/transforms/nlp.py
def get_value(self, index: int, /) -> str:
    """Get the token corresponding to an index."""
    if index in self._inverted_vocab:
        return self._inverted_vocab[index]
    raise KeyError(index)
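The lookup semantics above reduce to a small amount of dict logic; a dependency-free sketch with an illustrative vocabulary:

```python
vocab = {"<UNK>": 0, "hello": 1, "world": 2}
inverted = {i: t for t, i in vocab.items()}
unk_token = "<UNK>"


def get_index(value: str) -> int:
    # Known tokens map directly; unknown tokens fall back to unk_token,
    # and a KeyError is raised only when no unk_token is configured.
    if value in vocab:
        return vocab[value]
    if unk_token is not None:
        return vocab[unk_token]
    raise KeyError(value)


print(get_index("hello"))    # 1
print(get_index("missing"))  # 0  (falls back to <UNK>)
print(inverted[2])           # world
```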

Tokenizer

Bases: DataModule[DataModuleModeT, Union[str, Sequence[str], AnalyzedText], 'Tokenizer[AsInstance]', 'Tokenizer[AsBatch]'], Generic[DataModuleModeT]

Complete tokenization pipeline with multiple representation options.

Tokenizer is a DataModule that provides a unified interface for text tokenization with support for surface forms, part-of-speech tags, and character-level representations. It accepts raw text, pre-tokenized sequences, or analyzed text and converts them to indexed representations suitable for neural models.

CLASS TYPE PARAMETER DESCRIPTION
DataModuleModeT

Current mode (AsConverter, AsInstance, or AsBatch)

ATTRIBUTE DESCRIPTION
surfaces

Required token sequence indexer for surface forms (words).

TYPE: TokenSequenceIndexer

postags

Optional indexer for part-of-speech tags.

TYPE: Extra[TokenSequenceIndexer]

characters

Optional character-level indexer for tokens.

TYPE: Extra[TokenCharactersIndexer]

analyzer

Optional custom text analyzer/tokenizer function.

TYPE: Param[Callable[[str | Sequence[str] | AnalyzedText], AnalyzedText] | None]

Examples:

>>> # Basic tokenization
>>> tokenizer = Tokenizer(
...     surfaces=TokenSequenceIndexer(unk_token="<UNK>")
... )
>>>
>>> with tokenizer.train():
...     instance1 = tokenizer.instance("Hello world!")
...     instance2 = tokenizer.instance(["Hello", "world", "!"])
>>>
>>> batch = tokenizer.batch([instance1, instance2])
>>> print(batch.surfaces.ids.shape)  # (2, max_tokens)
>>> print(batch.surfaces.mask.shape)  # (2, max_tokens)
>>>
>>> # With POS tags and characters
>>> tokenizer = Tokenizer(
...     surfaces=TokenSequenceIndexer(unk_token="<UNK>"),
...     postags=TokenSequenceIndexer(unk_token="<UNK-POS>"),
...     characters=TokenCharactersIndexer()
... )
>>>
>>> analyzed = AnalyzedText(
...     surfaces=["Hello", "world"],
...     postags=["INTJ", "NOUN"]
... )
>>> instance = tokenizer.instance(analyzed)
>>> print(instance.surfaces)  # Indexed tokens
>>> print(instance.postags)  # Indexed POS tags
>>> print(instance.characters)  # Character indices
Note
  • Default analyzer uses punkt tokenization for raw strings
  • Accepts string, token list, or AnalyzedText as input
  • Extra fields (postags, characters) can be None
  • All indexers share the same training context

surfaces class-attribute instance-attribute

surfaces = field(default_factory=TokenSequenceIndexer)

postags class-attribute instance-attribute

postags = default(None)

characters class-attribute instance-attribute

characters = default(None)

text_vector class-attribute instance-attribute

text_vector = default(None)

token_vectors class-attribute instance-attribute

token_vectors = default(None)

analyzer class-attribute instance-attribute

analyzer = default(None)

accessor class-attribute instance-attribute

accessor = None

instance

instance(x)
Source code in src/formed/integrations/ml/transforms/nlp.py
def instance(self: "Tokenizer[AsConverter]", x: str | Sequence[str] | AnalyzedText, /) -> "Tokenizer[AsInstance]":
    analyzer = self.analyzer or self._default_analyzer
    return cast(DataModule[AsConverter], super()).instance(analyzer(x))

batch

batch(instances)

Collate multiple instances into a batched representation.

Takes a sequence of raw data or instances and creates a DataModule in AsBatch mode. Each transform field's batch() method is called to collate the corresponding field values.

PARAMETER DESCRIPTION
instances

Sequence of raw data or DataModule instances.

TYPE: Sequence[_T | _InstanceT]

RETURNS DESCRIPTION
_BatchT

A DataModule in AsBatch mode with batched tensor fields.

Examples:

>>> dm = TextDataModule(text=Tokenizer(), label=LabelIndexer())
>>> instances = [dm.instance(ex) for ex in examples]
>>> batch = dm.batch(instances)
>>> print(batch.text.surfaces.ids.shape)  # (batch_size, seq_length)
>>> print(batch.label.shape)  # (batch_size,)
>>> print(len(batch))  # batch_size
Note
  • Can only be called in AsConverter mode
  • Automatically converts raw data to instances if needed
  • Returns a new DataModule with mode=AsBatch
  • Extra fields are None if all instances have None for that field
Source code in src/formed/integrations/ml/transforms/base.py
def batch(self: "DataModule[AsConverter]", instances: Sequence[_T | _InstanceT]) -> _BatchT:
    """Collate multiple instances into a batched representation.

    Takes a sequence of raw data or instances and creates a `DataModule` in
    `AsBatch` mode. Each transform field's `batch()` method is called to collate
    the corresponding field values.

    Args:
        instances: Sequence of raw data or `DataModule` instances.

    Returns:
        A `DataModule` in `AsBatch` mode with batched tensor fields.

    Examples:
        >>> dm = TextDataModule(text=Tokenizer(), label=LabelIndexer())
        >>> instances = [dm.instance(ex) for ex in examples]
        >>> batch = dm.batch(instances)
        >>> print(batch.text.surfaces.ids.shape)  # (batch_size, seq_length)
        >>> print(batch.label.shape)  # (batch_size,)
        >>> print(len(batch))  # batch_size

    Note:
        - Can only be called in `AsConverter` mode
        - Automatically converts raw data to instances if needed
        - Returns a new `DataModule` with `mode=AsBatch`
        - Extra fields are `None` if all instances have None for that field

    """
    assert self.__mode__ in (None, DataModuleMode.AS_CONVERTER), (
        "DataModule must be in converter mode to create a batch"
    )

    instances = [item if isinstance(item, DataModule) else self.instance(item) for item in instances]
    fields = {}
    for name, transform in self.__field_transforms__.items():
        can_be_optional = name in self.__class__.__get_extra_fields__()
        values = [getattr(instance, name) for instance in instances]
        if can_be_optional and all(value is None for value in values):
            fields[name] = None
        else:
            fields[name] = transform.batch(values)
    for name in self.__class__.__get_param_fields__().keys():
        if name not in fields:
            fields[name] = _UNAVAILABLE

    batch = cast(_BatchT, dataclasses.replace(self, **fields))
    setattr(batch, "__mode__", DataModuleMode.AS_BATCH)

    batch._batch_size = len(instances)
    return batch
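The optional-field collation rule (an extra field batches to `None` only when every instance has `None` for it) can be sketched without formed; the field names here are hypothetical:

```python
from typing import Any, Optional, Sequence


def collate_field(values: Sequence[Any], optional: bool) -> Optional[list]:
    # Mirrors the rule above: an optional field collapses to None
    # only when it is absent from every instance in the batch.
    if optional and all(v is None for v in values):
        return None
    return list(values)


instances = [
    {"surfaces": [1, 2], "postags": None},
    {"surfaces": [3], "postags": None},
]
batch = {
    "surfaces": collate_field([i["surfaces"] for i in instances], optional=False),
    "postags": collate_field([i["postags"] for i in instances], optional=True),
}
print(batch)  # {'surfaces': [[1, 2], [3]], 'postags': None}
```

If even one instance supplied a `postags` value, the field would be collated rather than dropped.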

train

train()

Context manager to enable training mode for all field transforms.

This propagates training mode to all BaseTransform fields, allowing them to build state (e.g., vocabularies) from training data.

YIELDS DESCRIPTION
None

None

Examples:

>>> dm = TextDataModule(text=Tokenizer(), label=LabelIndexer())
>>> with dm.train():
...     instances = [dm.instance(example) for example in train_data]
>>> # Vocabularies are now built and frozen
Note

Can only be called in AsConverter mode.

Source code in src/formed/integrations/ml/transforms/base.py
@contextmanager
def train(self) -> Iterator[None]:
    """Context manager to enable training mode for all field transforms.

    This propagates training mode to all BaseTransform fields, allowing them
    to build state (e.g., vocabularies) from training data.

    Yields:
        None

    Examples:
        >>> dm = TextDataModule(text=Tokenizer(), label=LabelIndexer())
        >>> with dm.train():
        ...     instances = [dm.instance(example) for example in train_data]
        >>> # Vocabularies are now built and frozen

    Note:
        Can only be called in AsConverter mode.

    """
    assert self.__mode__ in (None, DataModuleMode.AS_CONVERTER), (
        "DataModule must be in converter mode to enter training mode"
    )
    with ExitStack() as stack:
        for transform in self.__field_transforms__.values():
            stack.enter_context(transform.train())
        yield


formed.integrations.ml.workflow

Workflow steps for machine learning data module integration.

This module provides workflow steps for training data modules and generating instances for machine learning tasks.

Available Steps
  • ml::train_datamodule: Train a data module on a dataset.
  • ml::train_datamodule_with_instances: Train a data module and collect generated instances.
  • ml::generate_instances: Generate instances from a dataset using a data module.
  • ml::generate_instances_without_caching: Generate instances without caching (same as ml::generate_instances but uncached).

DataModuleAndInstances dataclass

DataModuleAndInstances(datamodule, instances)

Bases: Generic[_InputT, _InstanceT]

datamodule instance-attribute

datamodule

instances instance-attribute

instances

DataModuleAndInstancesFormat

Bases: Format[DataModuleAndInstances[_InputT, _InstanceT]], Generic[_InputT, _InstanceT]

identifier property

identifier

Get the unique identifier for this format.

RETURNS DESCRIPTION
str

Format identifier string.

write

write(artifact, directory)
Source code in src/formed/integrations/ml/workflow.py
def write(
    self,
    artifact: DataModuleAndInstances[_InputT, _InstanceT],
    directory: Path,
) -> None:
    instances_path = directory / "instances"
    datamodule_path = directory / "datamodule"

    instances_path.mkdir(parents=True, exist_ok=True)
    datamodule_path.mkdir(parents=True, exist_ok=True)

    self._INSTANCES_FORMAT.write(artifact.instances, instances_path)
    self._DATAMODULE_FORMAT.write(artifact.datamodule, datamodule_path)

read

read(directory)
Source code in src/formed/integrations/ml/workflow.py
def read(self, directory: Path) -> DataModuleAndInstances[_InputT, _InstanceT]:
    instances_path = directory / "instances"
    datamodule_path = directory / "datamodule"

    instances = self._INSTANCES_FORMAT.read(instances_path)
    datamodule = self._DATAMODULE_FORMAT.read(datamodule_path)

    return DataModuleAndInstances(datamodule=datamodule, instances=instances)
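The composite layout above (one subdirectory per sub-artifact, each delegated to its own format) can be exercised with a JSON-based stand-in; `JsonFormat` here is hypothetical, not the formed implementation:

```python
import json
import tempfile
from pathlib import Path


class JsonFormat:
    """Hypothetical sub-format writing a single JSON file per directory."""

    def write(self, artifact, directory: Path) -> None:
        (directory / "data.json").write_text(json.dumps(artifact))

    def read(self, directory: Path):
        return json.loads((directory / "data.json").read_text())


fmt = JsonFormat()
with tempfile.TemporaryDirectory() as d:
    root = Path(d)
    # Mirror write(): one subdirectory per component of the artifact.
    for name, part in [("instances", [1, 2, 3]), ("datamodule", {"vocab_size": 10})]:
        (root / name).mkdir(parents=True, exist_ok=True)
        fmt.write(part, root / name)
    # Mirror read(): reassemble the components from their subdirectories.
    restored = {name: fmt.read(root / name) for name in ("instances", "datamodule")}

print(restored)  # {'instances': [1, 2, 3], 'datamodule': {'vocab_size': 10}}
```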

is_default_of classmethod

is_default_of(obj)

Check if this format is the default for the given object type.

PARAMETER DESCRIPTION
obj

Object to check.

TYPE: Any

RETURNS DESCRIPTION
bool

True if this format should be used by default for this type.

Source code in src/formed/workflow/format.py
@classmethod
def is_default_of(cls, obj: Any) -> bool:
    """Check if this format is the default for the given object type.

    Args:
        obj: Object to check.

    Returns:
        True if this format should be used by default for this type.

    """
    return False

train_datamodule

train_datamodule(datamodule, dataset)

Train a data module on a dataset.

This step trains a DataModule on the provided dataset, allowing it to learn transformations and build vocabularies.

PARAMETER DESCRIPTION
datamodule

DataModule to train.

TYPE: DataModule[AsConverter, _InputT]

dataset

Training dataset.

TYPE: Iterable[_InputT]

RETURNS DESCRIPTION
DataModule[AsConverter, _InputT]

Trained DataModule.

Source code in src/formed/integrations/ml/workflow.py
@step("ml::train_datamodule", format="json")
def train_datamodule(
    datamodule: DataModule[AsConverter, _InputT],
    dataset: Iterable[_InputT],
) -> DataModule[AsConverter, _InputT]:
    """Train a data module on a dataset.

    This step trains a DataModule on the provided dataset, allowing it to
    learn transformations and build vocabularies.

    Args:
        datamodule: DataModule to train.
        dataset: Training dataset.

    Returns:
        Trained DataModule.
    """
    with datamodule.train(), progress(dataset, desc="Training datamodule") as dataset:
        for example in dataset:
            datamodule(example)
    return datamodule

train_datamodule_with_instances

train_datamodule_with_instances(datamodule, dataset)

Train a data module and collect generated instances.

This step trains a DataModule while collecting all instances generated during training, returning both the trained module and instances.

PARAMETER DESCRIPTION
datamodule

DataModule to train.

TYPE: DataModule[AsConverter, _InputT, _InstanceT]

dataset

Training dataset.

TYPE: Iterable[_InputT]

RETURNS DESCRIPTION
DataModuleAndInstances[_InputT, _InstanceT]

DataModuleAndInstances containing the trained module and generated instances.

Source code in src/formed/integrations/ml/workflow.py
@step("ml::train_datamodule_with_instances", format=DataModuleAndInstancesFormat())
def train_datamodule_with_instances(
    datamodule: DataModule[AsConverter, _InputT, _InstanceT],
    dataset: Iterable[_InputT],
) -> DataModuleAndInstances[_InputT, _InstanceT]:
    """Train a data module and collect generated instances.

    This step trains a DataModule while collecting all instances generated
    during training, returning both the trained module and instances.

    Args:
        datamodule: DataModule to train.
        dataset: Training dataset.

    Returns:
        DataModuleAndInstances containing the trained module and generated instances.
    """

    def generate_instances() -> Iterator[_InstanceT]:
        nonlocal datamodule, dataset

        with datamodule.train(), progress(dataset, desc="Training datamodule") as dataset:
            for example in dataset:
                instance = datamodule(example)
                assert instance is not None
                yield instance

    return DataModuleAndInstances(datamodule=datamodule, instances=generate_instances())
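Note that the returned `instances` is a lazy generator: training state is built only as it is consumed, inside the training context. A minimal mimic of that single-pass pattern, where a `Counter` stands in for a vocabulary:

```python
from collections import Counter
from typing import Iterator

vocab = Counter()  # stands in for vocabularies built during training


def train_and_generate(dataset: list[list[str]]) -> Iterator[list[str]]:
    # One pass over the data: update training state and yield each instance.
    for example in dataset:
        vocab.update(example)
        yield example


instances = train_and_generate([["a", "b"], ["b", "c"]])
assert len(vocab) == 0           # nothing consumed yet: no state built
collected = list(instances)      # consuming drives both training and output
print(sorted(vocab), collected)  # ['a', 'b', 'c'] [['a', 'b'], ['b', 'c']]
```

A consumer (such as the step's format) must therefore iterate `instances` fully before relying on the trained `datamodule`.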

generate_instances

generate_instances(datamodule, dataset)

Generate instances from a dataset using a data module.

This step applies a DataModule to each example in the dataset, generating processed instances.

PARAMETER DESCRIPTION
datamodule

DataModule to use for instance generation.

TYPE: DataModule[AsConverter, _InputT, _InstanceT]

dataset

Input dataset.

TYPE: Iterable[_InputT]

RETURNS DESCRIPTION
Dataset[_InstanceT]

Dataset of generated instances.

Source code in src/formed/integrations/ml/workflow.py
@step("ml::generate_instances", format="dataset")
@step("ml::generate_instances_without_caching", cacheable=False)
def generate_instances(
    datamodule: DataModule[AsConverter, _InputT, _InstanceT],
    dataset: Iterable[_InputT],
) -> Dataset[_InstanceT]:
    """Generate instances from a dataset using a data module.

    This step applies a DataModule to each example in the dataset,
    generating processed instances.

    Args:
        datamodule: DataModule to use for instance generation.
        dataset: Input dataset.

    Returns:
        Dataset of generated instances.
    """

    def generator() -> Iterator[_InstanceT]:
        nonlocal datamodule, dataset
        with progress(dataset, desc="Generating instances") as dataset:
            for example in dataset:
                instance = datamodule(example)
                assert instance is not None
                yield instance

    return Dataset.from_iterable(generator())