Skip to content

MLflow

formed.integrations.mlflow.constants

DEFAULT_MLFLOW_EXPERIMENT_NAME module-attribute

DEFAULT_MLFLOW_EXPERIMENT_NAME = 'Default'

DEFAULT_MLFLOW_DIRECTORY module-attribute

DEFAULT_MLFLOW_DIRECTORY = (
    WORKFLOW_INTEGRATION_DIRECTORY / "mlflow"
)

formed.integrations.mlflow.utils

MlflowParamValue module-attribute

MlflowParamValue = Union[int, float, str, None]

MlflowParams module-attribute

MlflowParams = dict[str, MlflowParamValue]

WorkflowRunType

Bases: str, Enum

STEP class-attribute instance-attribute

STEP = 'step'

EXECUTION class-attribute instance-attribute

EXECUTION = 'execution'

WorkflowCacheStatus

Bases: str, Enum

PENDING class-attribute instance-attribute

PENDING = 'pending'

ACTIVE class-attribute instance-attribute

ACTIVE = 'active'

INACTIVE class-attribute instance-attribute

INACTIVE = 'inactive'

MlflowRunStatus

Bases: str, Enum

RUNNING class-attribute instance-attribute

RUNNING = 'RUNNING'

SCHEDULED class-attribute instance-attribute

SCHEDULED = 'SCHEDULED'

FINISHED class-attribute instance-attribute

FINISHED = 'FINISHED'

FAILED class-attribute instance-attribute

FAILED = 'FAILED'

KILLED class-attribute instance-attribute

KILLED = 'KILLED'

from_step_status classmethod

from_step_status(status)
Source code in src/formed/integrations/mlflow/utils.py
59
60
61
62
63
64
65
66
67
68
69
70
71
@classmethod
def from_step_status(cls, status: WorkflowStepStatus) -> "MlflowRunStatus":
    """Translate a workflow step status into the equivalent MLflow run status."""
    conversions = (
        (WorkflowStepStatus.PENDING, cls.SCHEDULED),
        (WorkflowStepStatus.RUNNING, cls.RUNNING),
        (WorkflowStepStatus.COMPLETED, cls.FINISHED),
        (WorkflowStepStatus.FAILURE, cls.FAILED),
        (WorkflowStepStatus.CANCELED, cls.KILLED),
    )
    # Compare with `==` (not a dict lookup) to keep the original equality
    # semantics of the enum comparisons.
    for step_status, run_status in conversions:
        if status == step_status:
            return run_status
    raise ValueError(f"Invalid step status: {status}")

from_execution_status classmethod

from_execution_status(status)
Source code in src/formed/integrations/mlflow/utils.py
73
74
75
76
77
78
79
80
81
82
83
84
85
@classmethod
def from_execution_status(cls, status: WorkflowExecutionStatus) -> "MlflowRunStatus":
    """Translate a workflow execution status into the equivalent MLflow run status."""
    conversions = (
        (WorkflowExecutionStatus.PENDING, cls.SCHEDULED),
        (WorkflowExecutionStatus.RUNNING, cls.RUNNING),
        (WorkflowExecutionStatus.COMPLETED, cls.FINISHED),
        (WorkflowExecutionStatus.FAILURE, cls.FAILED),
        (WorkflowExecutionStatus.CANCELED, cls.KILLED),
    )
    # Compare with `==` (not a dict lookup) to keep the original equality
    # semantics of the enum comparisons.
    for execution_status, run_status in conversions:
        if status == execution_status:
            return run_status
    raise ValueError(f"Invalid execution status: {status}")

to_step_status

to_step_status()
Source code in src/formed/integrations/mlflow/utils.py
87
88
89
90
91
92
93
94
95
96
97
98
def to_step_status(self) -> WorkflowStepStatus:
    """Translate this MLflow run status back into a workflow step status."""
    conversions = (
        (self.RUNNING, WorkflowStepStatus.RUNNING),
        (self.SCHEDULED, WorkflowStepStatus.PENDING),
        (self.FINISHED, WorkflowStepStatus.COMPLETED),
        (self.FAILED, WorkflowStepStatus.FAILURE),
        (self.KILLED, WorkflowStepStatus.CANCELED),
    )
    for run_status, step_status in conversions:
        if self == run_status:
            return step_status
    raise ValueError(f"Invalid run status: {self}")

to_execution_status

to_execution_status()
Source code in src/formed/integrations/mlflow/utils.py
100
101
102
103
104
105
106
107
108
109
110
111
def to_execution_status(self) -> WorkflowExecutionStatus:
    """Translate this MLflow run status back into a workflow execution status."""
    conversions = (
        (self.RUNNING, WorkflowExecutionStatus.RUNNING),
        (self.SCHEDULED, WorkflowExecutionStatus.PENDING),
        (self.FINISHED, WorkflowExecutionStatus.COMPLETED),
        (self.FAILED, WorkflowExecutionStatus.FAILURE),
        (self.KILLED, WorkflowExecutionStatus.CANCELED),
    )
    for run_status, execution_status in conversions:
        if self == run_status:
            return execution_status
    raise ValueError(f"Invalid run status: {self}")

MlflowTag

Bases: str, Enum

MLFLOW_PARENT_RUN_ID class-attribute instance-attribute

MLFLOW_PARENT_RUN_ID = MLFLOW_PARENT_RUN_ID

MLFLOW_RUN_NAME class-attribute instance-attribute

MLFLOW_RUN_NAME = MLFLOW_RUN_NAME

MLFLOW_RUN_NOTE class-attribute instance-attribute

MLFLOW_RUN_NOTE = MLFLOW_RUN_NOTE

MLFACTORY_RUN_TYPE class-attribute instance-attribute

MLFACTORY_RUN_TYPE = 'formed.workflow.run_type'

MLFACTORY_STEP_FINGERPRINT class-attribute instance-attribute

MLFACTORY_STEP_FINGERPRINT = (
    "formed.workflow.step.fingerprint"
)

MLFACTORY_STEP_CACHE_STATUS class-attribute instance-attribute

MLFACTORY_STEP_CACHE_STATUS = (
    "formed.workflow.step.cache_status"
)

flatten_params

flatten_params(d)
Source code in src/formed/integrations/mlflow/utils.py
126
127
128
129
130
131
132
133
134
135
def flatten_params(d: Mapping[str, JsonValue]) -> MlflowParams:
    """Flatten nested mappings (and sequences) into dot-separated MLflow param keys."""
    flattened: MlflowParams = {}
    for key, value in d.items():
        # Sequences become index-keyed mappings so they flatten uniformly.
        if isinstance(value, (list, tuple)):
            value = {str(index): item for index, item in enumerate(value)}
        if isinstance(value, dict):
            # Recurse and prefix each child key with the parent key.
            for subkey, subvalue in flatten_params(value).items():
                flattened[f"{key}.{subkey}"] = subvalue
        else:
            flattened[key] = value
    return flattened

build_filter_string

build_filter_string(
    run_type=None,
    step_info=None,
    execution_info=None,
    additional_filters=None,
)
Source code in src/formed/integrations/mlflow/utils.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def build_filter_string(
    run_type: Optional[WorkflowRunType] = None,
    step_info: Optional[Union[str, WorkflowStepInfo]] = None,
    execution_info: Optional[Union[str, WorkflowExecutionInfo]] = None,
    additional_filters: Optional[str] = None,
) -> str:
    """Compose an MLflow search filter string from the given criteria.

    Each non-None argument contributes one tag-equality clause; the clauses
    are joined with " AND ".
    """
    clauses: list[str] = []
    if run_type is not None:
        clauses.append(f"tags.{MlflowTag.MLFACTORY_RUN_TYPE.value} = '{run_type.value}'")
    if step_info is not None:
        # A WorkflowStepInfo is identified by its fingerprint; a str is used as-is.
        fingerprint = step_info.fingerprint if isinstance(step_info, WorkflowStepInfo) else step_info
        clauses.append(f"tags.{MlflowTag.MLFACTORY_STEP_FINGERPRINT.value} = '{fingerprint}'")
    if execution_info is not None:
        if isinstance(execution_info, WorkflowExecutionInfo):
            execution_id = execution_info.id
        else:
            execution_id = execution_info
        clauses.append(f"tags.{MlflowTag.MLFLOW_RUN_NAME.value} = '{execution_id}'")
    if additional_filters is not None:
        clauses.append(additional_filters)
    return " AND ".join(clauses)

is_mlflow_using_local_artifact_storage

is_mlflow_using_local_artifact_storage(mlflow_run)
Source code in src/formed/integrations/mlflow/utils.py
158
159
160
161
162
163
def is_mlflow_using_local_artifact_storage(
    mlflow_run: Union[str, MlflowRun],
) -> bool:
    """Return True when the run's artifact store is on the local filesystem.

    Accepts either a run ID or an ``MlflowRun`` object.
    """
    mlflow_run_id = mlflow_run.info.run_id if isinstance(mlflow_run, MlflowRun) else mlflow_run
    mlflow_artifact_uri = urlparse(artifact_utils.get_artifact_uri(run_id=mlflow_run_id))  # type: ignore[no-untyped-call]
    # `==` already yields a bool; the previous `bool(...)` wrapper was redundant.
    return mlflow_artifact_uri.scheme == "file"

get_mlflow_local_artifact_storage_path

get_mlflow_local_artifact_storage_path(mlflow_run)
Source code in src/formed/integrations/mlflow/utils.py
166
167
168
169
170
171
172
173
def get_mlflow_local_artifact_storage_path(
    mlflow_run: Union[str, MlflowRun],
) -> Optional[Path]:
    """Return the local artifact directory for the run, or None for non-local storage."""
    if isinstance(mlflow_run, MlflowRun):
        run_id = mlflow_run.info.run_id
    else:
        run_id = mlflow_run
    parsed_uri = urlparse(artifact_utils.get_artifact_uri(run_id=run_id))  # type: ignore[no-untyped-call]
    if parsed_uri.scheme != "file":
        return None
    return Path(parsed_uri.path)

fetch_child_mlflow_runs

fetch_child_mlflow_runs(client, experiment, mlflow_run)
Source code in src/formed/integrations/mlflow/utils.py
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
def fetch_child_mlflow_runs(
    client: MlflowClient,
    experiment: Union[str, MlflowExperiment],
    mlflow_run: Union[str, MlflowRun],
) -> Iterator[MlflowRun]:
    """Yield every run whose MLflow parent-run tag points at ``mlflow_run``.

    Pages through ``search_runs`` results until the paging token is exhausted.
    Accepts a run ID or an ``MlflowRun`` for ``mlflow_run``.
    """
    experiment = get_mlflow_experiment(experiment)
    mlflow_run_id = mlflow_run.info.run_id if isinstance(mlflow_run, MlflowRun) else mlflow_run
    page_token: Optional[str] = None
    while True:
        runs = client.search_runs(
            experiment_ids=[experiment.experiment_id],
            filter_string=f"tags.{MlflowTag.MLFLOW_PARENT_RUN_ID.value} = '{mlflow_run_id}'",
            page_token=page_token,
        )
        # The search result is directly iterable; the previous `iter(...)`
        # wrapper was redundant.
        yield from runs
        if runs.token is None:
            break
        page_token = runs.token

update_mlflow_tags

update_mlflow_tags(client, run, tags)
Source code in src/formed/integrations/mlflow/utils.py
196
197
198
199
200
201
202
203
def update_mlflow_tags(
    client: MlflowClient,
    run: Union[str, MlflowRun],
    tags: dict[MlflowTag, str],
) -> None:
    """Set each of the given tags on the run (accepts a run ID or MlflowRun)."""
    if isinstance(run, MlflowRun):
        run_id = run.info.run_id
    else:
        run_id = run
    for tag_key, tag_value in tags.items():
        client.set_tag(run_id=run_id, key=tag_key.value, value=tag_value)

fetch_mlflow_runs

fetch_mlflow_runs(
    client,
    experiment,
    *,
    run_type=None,
    step_info=None,
    execution_info=None,
    with_children=False,
)
Source code in src/formed/integrations/mlflow/utils.py
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
def fetch_mlflow_runs(
    client: MlflowClient,
    experiment: Union[str, MlflowExperiment],
    *,
    run_type: Optional[WorkflowRunType] = None,
    step_info: Optional[Union[str, WorkflowStepInfo]] = None,
    execution_info: Optional[Union[str, WorkflowExecutionInfo]] = None,
    with_children: bool = False,
) -> Iterator[MlflowRun]:
    """Yield runs matching the given criteria, paging through search results.

    When ``with_children`` is True, each matching run is followed immediately
    by its child runs.
    """
    experiment = get_mlflow_experiment(experiment)

    search_filter = build_filter_string(run_type, step_info, execution_info)
    page_token: Optional[str] = None
    while True:
        page = client.search_runs(
            experiment_ids=[experiment.experiment_id],
            filter_string=search_filter,
            page_token=page_token,
        )
        for matched_run in page:
            yield matched_run
            if with_children:
                yield from fetch_child_mlflow_runs(client, experiment, matched_run)
        page_token = page.token
        if page_token is None:
            break

fetch_mlflow_run

fetch_mlflow_run(
    client,
    experiment,
    *,
    run_type=None,
    step_info=None,
    execution_info=None,
)
Source code in src/formed/integrations/mlflow/utils.py
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
def fetch_mlflow_run(
    client: MlflowClient,
    experiment: Union[str, MlflowExperiment],
    *,
    run_type: Optional[WorkflowRunType] = None,
    step_info: Optional[Union[str, WorkflowStepInfo]] = None,
    execution_info: Optional[Union[str, WorkflowExecutionInfo]] = None,
) -> Optional[MlflowRun]:
    """Return the first run matching the criteria, or None when nothing matches."""
    matches = fetch_mlflow_runs(
        client,
        experiment,
        run_type=run_type,
        step_info=step_info,
        execution_info=execution_info,
    )
    return next(matches, None)

generate_new_execution_id

generate_new_execution_id(client, experiment)
Source code in src/formed/integrations/mlflow/utils.py
254
255
256
257
258
259
def generate_new_execution_id(client: MlflowClient, experiment: Union[str, MlflowExperiment]) -> WorkflowExecutionID:
    """Draw random 8-hex-character IDs until one has no existing run in the experiment."""
    experiment = get_mlflow_experiment(experiment)
    while True:
        candidate = uuid.uuid4().hex[:8]
        existing = fetch_mlflow_run(client, experiment, execution_info=candidate)
        if existing is None:
            return WorkflowExecutionID(candidate)

add_mlflow_run

add_mlflow_run(
    client,
    experiment,
    step_or_execution_info,
    parent_run_id=None,
)
Source code in src/formed/integrations/mlflow/utils.py
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
def add_mlflow_run(
    client: MlflowClient,
    experiment: Union[str, MlflowExperiment],
    step_or_execution_info: Union[WorkflowStepInfo, WorkflowExecutionInfo],
    parent_run_id: Optional[str] = None,
) -> MlflowRun:
    """Create an MLflow run for a workflow step or execution and log its params.

    Raises:
        ValueError: If ``step_or_execution_info`` is neither a step info nor
            an execution info.
    """
    experiment = get_mlflow_experiment(experiment)

    if isinstance(step_or_execution_info, WorkflowStepInfo):
        run_name = step_or_execution_info.name
        params = get_step_params(step_or_execution_info)
        tags = get_step_tags(step_or_execution_info)
    elif isinstance(step_or_execution_info, WorkflowExecutionInfo):
        assert step_or_execution_info.id is not None
        run_name = step_or_execution_info.id
        params = get_execution_params(step_or_execution_info)
        tags = get_execution_tags(step_or_execution_info)
    else:
        raise ValueError(f"Unsupported type: {type(step_or_execution_info)}")

    if parent_run_id is not None:
        # Linking via the parent-run tag nests this run under its parent.
        tags[MlflowTag.MLFLOW_PARENT_RUN_ID] = parent_run_id

    resolved_tags = context_registry.resolve_tags({tag.value: value for tag, value in tags.items()})
    run = client.create_run(
        experiment_id=experiment.experiment_id,
        run_name=run_name,
        tags=resolved_tags,
    )
    for param_key, param_value in params.items():
        client.log_param(run.info.run_id, param_key, param_value)

    return run

download_mlflow_artifacts

download_mlflow_artifacts(
    client,
    experiment,
    step_or_execution_info,
    directory,
    artifact_path=None,
)
Source code in src/formed/integrations/mlflow/utils.py
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
def download_mlflow_artifacts(
    client: MlflowClient,
    experiment: Union[str, MlflowExperiment],
    step_or_execution_info: Union[WorkflowStepInfo, WorkflowExecutionInfo],
    directory: Union[str, PathLike],
    artifact_path: Optional[str] = None,
) -> None:
    """Download the artifacts of a step/execution run into ``directory``.

    Locates the MLflow run for the given step or execution info, downloads
    its artifacts (optionally restricted to ``artifact_path``) into a
    temporary directory, then moves the downloaded entries into
    ``directory``.

    Raises:
        FileNotFoundError: If no matching MLflow run exists.
    """
    run_type = (
        WorkflowRunType.STEP if isinstance(step_or_execution_info, WorkflowStepInfo) else WorkflowRunType.EXECUTION
    )
    step_info = step_or_execution_info if isinstance(step_or_execution_info, WorkflowStepInfo) else None
    execution_info = step_or_execution_info if isinstance(step_or_execution_info, WorkflowExecutionInfo) else None
    run = fetch_mlflow_run(
        client,
        experiment,
        run_type=run_type,
        step_info=step_info,
        execution_info=execution_info,
    )
    if run is None:
        raise FileNotFoundError("Run not found")
    directory = Path(directory)
    # Download into a temporary directory first so a failed download never
    # leaves partial files in the destination.
    with tempfile.TemporaryDirectory() as temp_dir:
        mlflow.artifacts.download_artifacts(
            run_id=run.info.run_id,
            artifact_path=artifact_path,
            dst_path=temp_dir,
        )
        download_path = Path(temp_dir)
        if artifact_path is not None:
            # download_artifacts recreates the artifact_path hierarchy under dst_path.
            download_path = download_path / artifact_path
        directory.mkdir(parents=True, exist_ok=True)
        # NOTE(review): shutil.move's behavior when the destination entry
        # already exists is platform/type-dependent — assumes destination
        # entries do not pre-exist; confirm with callers.
        for path in download_path.glob("*"):
            shutil.move(path, directory / path.name)

upload_mlflow_artifacts

upload_mlflow_artifacts(
    client,
    experiment,
    step_or_execution_info,
    directory,
    artifact_path=None,
)
Source code in src/formed/integrations/mlflow/utils.py
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
def upload_mlflow_artifacts(
    client: MlflowClient,
    experiment: Union[str, MlflowExperiment],
    step_or_execution_info: Union[WorkflowStepInfo, WorkflowExecutionInfo],
    directory: Union[str, PathLike],
    artifact_path: Optional[str] = None,
) -> None:
    """Upload the contents of ``directory`` as artifacts of the matching run.

    Raises:
        ValueError: If no matching MLflow run exists.
    """
    step_info = step_or_execution_info if isinstance(step_or_execution_info, WorkflowStepInfo) else None
    execution_info = step_or_execution_info if isinstance(step_or_execution_info, WorkflowExecutionInfo) else None
    run = fetch_mlflow_run(
        client,
        experiment,
        step_info=step_info,
        execution_info=execution_info,
    )
    if run is None:
        raise ValueError("Run not found")
    client.log_artifacts(
        run_id=run.info.run_id,
        local_dir=str(directory),
        artifact_path=artifact_path,
    )

terminate_mlflow_run

terminate_mlflow_run(
    client, experiment, step_or_execution_state
)
Source code in src/formed/integrations/mlflow/utils.py
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
def terminate_mlflow_run(
    client: MlflowClient,
    experiment: Union[str, MlflowExperiment],
    step_or_execution_state: Union[WorkflowStepState, WorkflowExecutionState],
) -> None:
    """Terminate the MLflow run backing a finished workflow step or execution.

    The state must already carry a terminal status (COMPLETED / FAILURE /
    CANCELED); the matching MLflow run is closed with the mapped status and,
    when available, the state's finish timestamp.

    Raises:
        ValueError: If the status is not terminal, the argument is of an
            unsupported type, or no matching run exists.
    """
    step_info: Optional[str] = None
    execution_info: Optional[str] = None
    run_status: MlflowRunStatus
    if isinstance(step_or_execution_state, WorkflowStepState):
        if step_or_execution_state.status not in (
            WorkflowStepStatus.COMPLETED,
            WorkflowStepStatus.FAILURE,
            WorkflowStepStatus.CANCELED,
        ):
            raise ValueError(f"Invalid step status: {step_or_execution_state.status}")
        step_info = step_or_execution_state.fingerprint
        run_status = MlflowRunStatus.from_step_status(step_or_execution_state.status)
    elif isinstance(step_or_execution_state, WorkflowExecutionStatus.__class__) if False else isinstance(step_or_execution_state, WorkflowExecutionState):
        if step_or_execution_state.status not in (
            WorkflowExecutionStatus.COMPLETED,
            WorkflowExecutionStatus.FAILURE,
            WorkflowExecutionStatus.CANCELED,
        ):
            raise ValueError(f"Invalid execution status: {step_or_execution_state.status}")
        execution_info = step_or_execution_state.execution_id
        run_status = MlflowRunStatus.from_execution_status(step_or_execution_state.status)
    else:
        # BUG FIX: previously an unsupported type fell through with
        # `run_status` unbound, producing an opaque UnboundLocalError below;
        # fail fast with an explicit error instead (message style matches
        # add_mlflow_run).
        raise ValueError(f"Unsupported type: {type(step_or_execution_state)}")
    # Both branches computed the same expression; MLflow expects the end time
    # in milliseconds since the epoch.
    end_time = (
        int(step_or_execution_state.finished_at.timestamp() * 1000) if step_or_execution_state.finished_at else None
    )
    experiment = get_mlflow_experiment(experiment)
    run = fetch_mlflow_run(
        client,
        experiment,
        step_info=step_info,
        execution_info=execution_info,
    )
    if run is None:
        raise ValueError("Run not found")
    client.set_terminated(
        run_id=run.info.run_id,
        status=run_status.value,
        end_time=end_time,
    )

remove_mlflow_run

remove_mlflow_run(
    client, experiment, step_or_execution_info
)
Source code in src/formed/integrations/mlflow/utils.py
407
408
409
410
411
412
413
414
415
416
417
418
419
420
def remove_mlflow_run(
    client: MlflowClient,
    experiment: Union[str, MlflowExperiment],
    step_or_execution_info: Union[WorkflowStepInfo, WorkflowExecutionInfo],
) -> None:
    """Delete the MLflow run associated with the given step or execution.

    Raises:
        ValueError: If no matching MLflow run exists.
    """
    step_info = step_or_execution_info if isinstance(step_or_execution_info, WorkflowStepInfo) else None
    execution_info = step_or_execution_info if isinstance(step_or_execution_info, WorkflowExecutionInfo) else None
    run = fetch_mlflow_run(
        client,
        experiment,
        step_info=step_info,
        execution_info=execution_info,
    )
    if run is None:
        raise ValueError("Run not found")
    client.delete_run(run.info.run_id)

get_mlflow_experiment

get_mlflow_experiment(experiment)
Source code in src/formed/integrations/mlflow/utils.py
423
424
425
426
427
428
429
430
431
432
def get_mlflow_experiment(experiment: Union[str, MlflowExperiment]) -> MlflowExperiment:
    """Resolve an experiment name to an ``MlflowExperiment``, creating it if absent.

    ``MlflowExperiment`` arguments are returned unchanged.
    """
    if isinstance(experiment, str):
        client = MlflowClient()
        experiment_or_none = client.get_experiment_by_name(experiment)
        if experiment_or_none is None:
            # Use the experiment ID returned by create_experiment directly
            # instead of a second by-name lookup guarded by an assert.
            experiment_id = mlflow.create_experiment(experiment)
            experiment_or_none = client.get_experiment(experiment_id)
        experiment = experiment_or_none
    return experiment

get_step_params

get_step_params(step_info)
Source code in src/formed/integrations/mlflow/utils.py
435
436
437
438
def get_step_params(step_info: WorkflowStepInfo) -> MlflowParams:
    """Flattened MLflow params derived from the step's configuration."""
    step_config = as_jsonvalue(step_info.step.config)
    assert isinstance(step_config, dict)
    return flatten_params(step_config)

get_execution_params

get_execution_params(execution_info)
Source code in src/formed/integrations/mlflow/utils.py
441
442
443
444
def get_execution_params(execution_info: WorkflowExecutionInfo) -> MlflowParams:
    """Flattened MLflow params derived from the execution's workflow graph."""
    # NOTE(review): assumes `graph.json()` returns a dict (the assert below
    # enforces it at runtime); a pydantic-style `.json()` would return a
    # str — confirm against the WorkflowGraph API.
    config = execution_info.graph.json()
    assert isinstance(config, dict)
    return flatten_params(config)

get_step_tags

get_step_tags(step_info)
Source code in src/formed/integrations/mlflow/utils.py
447
448
449
450
451
def get_step_tags(step_info: WorkflowStepInfo) -> dict[MlflowTag, str]:
    """MLflow tags identifying a step run (run type plus step fingerprint)."""
    tags: dict[MlflowTag, str] = {}
    tags[MlflowTag.MLFACTORY_RUN_TYPE] = WorkflowRunType.STEP.value
    tags[MlflowTag.MLFACTORY_STEP_FINGERPRINT] = step_info.fingerprint
    return tags

get_execution_tags

get_execution_tags(execution_info)
Source code in src/formed/integrations/mlflow/utils.py
454
455
456
457
458
459
def get_execution_tags(execution_info: WorkflowExecutionInfo) -> dict[MlflowTag, str]:
    """MLflow tags identifying an execution run (run name stores the execution ID)."""
    assert execution_info.id is not None
    tags: dict[MlflowTag, str] = {}
    tags[MlflowTag.MLFLOW_RUN_NAME] = execution_info.id
    tags[MlflowTag.MLFACTORY_RUN_TYPE] = WorkflowRunType.EXECUTION.value
    return tags

get_mlflow_tags_from_run

get_mlflow_tags_from_run(run)
Source code in src/formed/integrations/mlflow/utils.py
462
463
def get_mlflow_tags_from_run(run: MlflowRun) -> dict[MlflowTag, str]:
    """Extract the known MlflowTag entries that are present on the run."""
    found: dict[MlflowTag, str] = {}
    run_tags = run.data.tags
    for tag in MlflowTag:
        if tag.value in run_tags:
            found[tag] = run_tags[tag.value]
    return found

get_execution_state_from_run

get_execution_state_from_run(run)
Source code in src/formed/integrations/mlflow/utils.py
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
def get_execution_state_from_run(run: MlflowRun) -> WorkflowExecutionState:
    """Reconstruct a ``WorkflowExecutionState`` from an MLflow execution run.

    The run must be tagged as an EXECUTION run and carry a run name (which
    stores the execution ID).

    Raises:
        ValueError: If the run-type tag is missing or not EXECUTION, or the
            run has no name.
    """
    tags = get_mlflow_tags_from_run(run)
    if MlflowTag.MLFACTORY_RUN_TYPE not in tags:
        raise ValueError("Run type not found")
    if tags[MlflowTag.MLFACTORY_RUN_TYPE] != WorkflowRunType.EXECUTION.value:
        raise ValueError(f"Invalid run type: {tags[MlflowTag.MLFACTORY_RUN_TYPE]}")
    if run.info.run_name is None:
        raise ValueError("Run name not found")
    # The execution ID is stored as the MLflow run name (see get_execution_tags).
    execution_id = WorkflowExecutionID(run.info.run_name)
    status = MlflowRunStatus(run.info.status).to_execution_status()
    # MLflow timestamps are milliseconds since the epoch.
    # NOTE(review): fromtimestamp() yields naive, local-time datetimes —
    # confirm callers expect local time rather than UTC.
    started_at = datetime.datetime.fromtimestamp(run.info.start_time / 1000)
    finished_at = datetime.datetime.fromtimestamp(run.info.end_time / 1000) if run.info.end_time else None
    return WorkflowExecutionState(
        execution_id=execution_id,
        status=status,
        started_at=started_at,
        finished_at=finished_at,
    )

formed.integrations.mlflow.workflow

logger module-attribute

logger = getLogger(__name__)

T module-attribute

T = TypeVar('T')

MlflowWorkflowCache

MlflowWorkflowCache(
    experiment_name=DEFAULT_MLFLOW_EXPERIMENT_NAME,
    directory=None,
    mlflow_client=None,
)

Bases: WorkflowCache

Source code in src/formed/integrations/mlflow/workflow.py
75
76
77
78
79
80
81
82
83
84
def __init__(
    self,
    experiment_name: str = DEFAULT_MLFLOW_EXPERIMENT_NAME,
    directory: Optional[Union[str, PathLike]] = None,
    mlflow_client: Optional[MlflowClient] = None,
) -> None:
    """Initialize the cache with an MLflow client and a local cache directory."""
    self._experiment_name = experiment_name
    self._client = mlflow_client or MlflowClient()
    # Fall back to the class-level default directory and make sure it exists.
    self._directory = Path(directory or self._DEFAULT_DIRECTORY)
    self._directory.mkdir(parents=True, exist_ok=True)

MlflowWorkflowCallback

MlflowWorkflowCallback(
    experiment_name=DEFAULT_MLFLOW_EXPERIMENT_NAME,
    mlflow_client=None,
    log_execution_metrics=False,
)

Bases: WorkflowCallback

Source code in src/formed/integrations/mlflow/workflow.py
199
200
201
202
203
204
205
206
207
208
209
210
211
212
def __init__(
    self,
    experiment_name: str = DEFAULT_MLFLOW_EXPERIMENT_NAME,
    mlflow_client: Optional[MlflowClient] = None,
    log_execution_metrics: bool = False,
) -> None:
    """Initialize per-execution bookkeeping for the MLflow callback."""
    self._experiment_name = experiment_name
    self._client = mlflow_client or MlflowClient()
    self._log_execution_metrics = log_execution_metrics
    # Populated while an execution is in flight.
    self._execution_run: Optional[MlflowRun] = None
    self._execution_log: Optional[LogCapture[StringIO]] = None
    self._step_log: dict[WorkflowStepInfo, LogCapture[StringIO]] = {}
    # step name -> MLflow run ID, and step name -> dependent step names.
    self._step_run_ids: dict[str, str] = {}
    self._dependents_map: dict[str, set[str]] = {}

on_execution_start

on_execution_start(execution_context)
Source code in src/formed/integrations/mlflow/workflow.py
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
def on_execution_start(
    self,
    execution_context: WorkflowExecutionContext,
) -> None:
    """Open an MLflow run for the execution and start capturing its log.

    Assigns a fresh execution ID when the context has none, logs the
    execution metadata as an artifact, and seeds the run's description note.
    """
    # A previous execution must have been closed via on_execution_end.
    assert self._execution_run is None
    execution_info = execution_context.info
    if execution_info.id is None:
        execution_info.id = mlflow_utils.generate_new_execution_id(self._client, self._experiment_name)
    self._execution_log = LogCapture(StringIO())
    self._execution_log.start()
    self._execution_run = mlflow_utils.add_mlflow_run(
        self._client,
        self._experiment_name,
        execution_info,
    )
    # Persist the execution metadata as a JSON artifact on the run.
    # NOTE(review): log_dict expects a dict — confirm `execution_info.json()`
    # returns one (a pydantic-style `.json()` would return a str).
    self._client.log_dict(
        run_id=self._execution_run.info.run_id,
        dictionary=execution_info.json(),
        artifact_file=self._EXECUTION_METADATA_ARTIFACT_FILENAME,
    )

    # Initialize note bookkeeping: dependents of each step and run-ID map.
    self._dependents_map = self._build_dependents_map(execution_info.graph)
    self._step_run_ids = {}

    # Seed the execution run's description note (updated as steps start).
    initial_note = self._generate_execution_note_markdown(
        execution_info,
        self._step_run_ids,
        self._dependents_map,
    )
    self._update_run_note(self._execution_run.info.run_id, initial_note)

on_execution_end

on_execution_end(execution_context)
Source code in src/formed/integrations/mlflow/workflow.py
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
def on_execution_end(
    self,
    execution_context: "WorkflowExecutionContext",
) -> None:
    """Terminate the execution's MLflow run and upload the captured log."""
    assert self._execution_run is not None
    mlflow_utils.terminate_mlflow_run(
        self._client,
        self._experiment_name,
        execution_context.state,
    )
    captured = self._execution_log
    if captured is not None:
        captured.stop()
        self._client.log_text(
            run_id=self._execution_run.info.run_id,
            text=captured.stream.getvalue(),
            artifact_file=self._LOG_FILENAME,
        )
        captured.stream.close()
    self._execution_run = None

on_step_start

on_step_start(step_context, execution_context)
Source code in src/formed/integrations/mlflow/workflow.py
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
def on_step_start(
    self,
    step_context: WorkflowStepContext,
    execution_context: WorkflowExecutionContext,
) -> None:
    """Open a nested MLflow run for the step and start capturing its log.

    The step run is nested under the current execution run, its metadata is
    logged as an artifact, and both the step's and the execution's
    description notes are (re)generated to reflect the new run.
    """
    # Steps may only start while the execution run is open.
    assert self._execution_run is not None
    step_info = step_context.info
    # Nest the step run under the execution run via the parent-run tag.
    run = mlflow_utils.add_mlflow_run(
        self._client,
        self._experiment_name,
        step_info,
        parent_run_id=self._execution_run.info.run_id,
    )
    self._client.log_dict(
        run_id=run.info.run_id,
        dictionary=step_info.json(),
        artifact_file=self._STEP_METADATA_ARTIFACT_FILENAME,
    )
    # Capture only this step's logger output into its own buffer.
    self._step_log[step_info] = LogCapture(StringIO(), logger=get_step_logger_from_info(step_info))
    self._step_log[step_info].start()

    # Record the step's run ID so notes can link to it.
    self._step_run_ids[step_info.name] = run.info.run_id

    # Set the step run's description note.
    step_note = self._generate_step_note_markdown(
        step_info,
        execution_context.info,
        self._step_run_ids,
        self._dependents_map,
    )
    self._update_run_note(run.info.run_id, step_note)

    # Regenerate the execution note so it includes the new step run ID.
    execution_note = self._generate_execution_note_markdown(
        execution_context.info,
        self._step_run_ids,
        self._dependents_map,
    )
    self._update_run_note(self._execution_run.info.run_id, execution_note)

on_step_end

on_step_end(step_context, execution_context)
Source code in src/formed/integrations/mlflow/workflow.py
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
def on_step_end(
    self,
    step_context: WorkflowStepContext,
    execution_context: WorkflowExecutionContext,
) -> None:
    """Terminate the step's MLflow run, upload its captured log, and log metrics.

    When the step declares the METRICS result flag and completed
    successfully, its cached result (a dict of metric name -> value) is
    logged on the step run and, when enabled, mirrored onto the execution
    run under a ``<step name>/`` prefix.
    """
    step_info = step_context.info
    mlflow_utils.terminate_mlflow_run(
        self._client,
        self._experiment_name,
        step_context.state,
    )
    if (step_log := self._step_log.pop(step_info, None)) is not None:
        # The run was created in on_step_start; look it up by step fingerprint.
        run = mlflow_utils.fetch_mlflow_run(
            self._client,
            self._experiment_name,
            step_info=step_info,
        )
        if run is None:
            raise RuntimeError(f"Run for step {step_info} not found")
        step_log.stop()
        self._client.log_text(
            run_id=run.info.run_id,
            text=step_log.stream.getvalue(),
            artifact_file=self._LOG_FILENAME,
        )
        if (
            WorkflowStepResultFlag.METRICS in WorkflowStepResultFlag.get_flags(step_info)
            and step_context.state.status == WorkflowStepStatus.COMPLETED
        ):
            metrics = execution_context.cache[step_info]
            assert isinstance(metrics, dict), f"Expected dict, got {type(metrics)}"
            for key, value in metrics.items():
                self._client.log_metric(run.info.run_id, key, value)
            if self._log_execution_metrics:
                assert self._execution_run is not None
                # Prefix with the step name so metrics from different steps
                # don't collide on the execution run.
                for key, value in metrics.items():
                    key = f"{step_info.name}/{key}"
                    self._client.log_metric(self._execution_run.info.run_id, key, value)
        step_log.stream.close()

MlflowWorkflowOrganizer

MlflowWorkflowOrganizer(
    experiment_name=DEFAULT_MLFLOW_EXPERIMENT_NAME,
    cache=None,
    callbacks=None,
    log_execution_metrics=None,
)

Bases: WorkflowOrganizer

Source code in src/formed/integrations/mlflow/workflow.py
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
def __init__(
    self,
    experiment_name: str = DEFAULT_MLFLOW_EXPERIMENT_NAME,
    cache: Optional[WorkflowCache] = None,
    callbacks: Optional[Union[WorkflowCallback, Sequence[WorkflowCallback]]] = None,
    log_execution_metrics: Optional[bool] = None,
) -> None:
    """Wire up an MLflow-backed cache and callback, then delegate to the base organizer."""
    self._client = MlflowClient()
    self._experiment_name = experiment_name

    cache = cache or MlflowWorkflowCache(
        experiment_name=experiment_name,
        mlflow_client=self._client,
    )

    # Normalize callbacks into a mutable list.
    if isinstance(callbacks, WorkflowCallback):
        normalized_callbacks = [callbacks]
    elif callbacks is None:
        normalized_callbacks = []
    else:
        normalized_callbacks = list(callbacks)

    has_mlflow_callback = any(isinstance(cb, MlflowWorkflowCallback) for cb in normalized_callbacks)
    if has_mlflow_callback:
        if log_execution_metrics is not None:
            logger.warning(
                "Ignoring `log_execution_metrics` parameter because `MlflowWorkflowCallback` is already present"
            )
    else:
        # Ensure the MLflow callback runs first.
        normalized_callbacks.insert(
            0,
            MlflowWorkflowCallback(
                experiment_name,
                mlflow_client=self._client,
                log_execution_metrics=log_execution_metrics or False,
            ),
        )

    super().__init__(cache, normalized_callbacks)

cache instance-attribute

cache = cache

callback instance-attribute

callback = MultiWorkflowCallback(callbacks or [])

run

run(executor, execution)
Source code in src/formed/integrations/mlflow/workflow.py
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
def run(
    self,
    executor: WorkflowExecutor,
    execution: Union[WorkflowGraph, WorkflowExecutionInfo],
) -> WorkflowExecutionContext:
    """Run `execution` with the MLflow experiment bound in an isolated context."""
    parent_run = super().run

    def _run_with_experiment() -> WorkflowExecutionContext:
        # Bind the experiment through a context variable so nested code can
        # find it without it being threaded through every call.
        _MLFLOW_EXPERIMENT.set(get_mlflow_experiment(self._experiment_name))
        return parent_run(executor, execution)

    # Running inside a copied context keeps the binding from leaking out.
    return contextvars.copy_context().run(_run_with_experiment)

get

get(execution_id)
Source code in src/formed/integrations/mlflow/workflow.py
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
def get(self, execution_id: WorkflowExecutionID) -> Optional[WorkflowExecutionContext]:
    """Reconstruct the execution context stored in MLflow, or None if absent."""
    run = mlflow_utils.fetch_mlflow_run(
        self._client,
        self._experiment_name,
        execution_info=execution_id,
    )
    if run is None:
        return None

    artifact_uri = run.info.artifact_uri
    if not artifact_uri:
        raise RuntimeError(f"Run {run.info.run_id} has no artifact URI")

    # Rehydrate the execution metadata from the JSON artifact logged at run time.
    metadata_uri = f"{artifact_uri}/{MlflowWorkflowCallback._EXECUTION_METADATA_ARTIFACT_FILENAME}"
    execution_info = WorkflowExecutionInfo.from_json(mlflow.artifacts.load_dict(metadata_uri))

    execution_state = mlflow_utils.get_execution_state_from_run(run)
    return WorkflowExecutionContext(execution_info, execution_state, self.cache, self.callback)

exists

exists(execution_id)
Source code in src/formed/integrations/mlflow/workflow.py
585
586
587
588
589
590
591
def exists(self, execution_id: WorkflowExecutionID) -> bool:
    """Return True if a run for `execution_id` exists in the experiment."""
    return (
        mlflow_utils.fetch_mlflow_run(
            self._client,
            self._experiment_name,
            execution_info=execution_id,
        )
        is not None
    )

remove

remove(execution_id)
Source code in src/formed/integrations/mlflow/workflow.py
593
594
595
596
597
598
599
600
601
def remove(self, execution_id: WorkflowExecutionID) -> None:
    """Delete the run recorded for `execution_id`, including its child runs."""
    matching_runs = mlflow_utils.fetch_mlflow_runs(
        self._client,
        self._experiment_name,
        execution_info=execution_id,
        with_children=True,
    )
    for matched in matching_runs:
        logger.info(f"Removing run {matched.info.run_id}")
        self._client.delete_run(matched.info.run_id)

MlflowLogger

MlflowLogger(run)
Source code in src/formed/integrations/mlflow/workflow.py
607
608
def __init__(self, run: MlflowRun):
    # Bind the logger to the MLflow run that all subsequent log_* calls target.
    self.run = run

run instance-attribute

run = run

mlflow_client property

mlflow_client

log_metric

log_metric(key, value, timestamp=None, step=None)
Source code in src/formed/integrations/mlflow/workflow.py
625
626
627
628
629
630
631
632
633
634
635
636
637
638
def log_metric(
    self,
    key: str,
    value: float,
    timestamp: Optional[int] = None,
    step: Optional[int] = None,
) -> None:
    """Log one metric value against the bound run."""
    target_run_id = self.run.info.run_id
    self.mlflow_client.log_metric(
        run_id=target_run_id,
        key=key,
        value=value,
        timestamp=timestamp,
        step=step,
    )

log_metrics

log_metrics(metrics)
Source code in src/formed/integrations/mlflow/workflow.py
640
641
642
def log_metrics(self, metrics: dict[str, float]) -> None:
    """Log each entry of `metrics` as a separate metric on the bound run."""
    for metric_name, metric_value in metrics.items():
        self.log_metric(metric_name, metric_value)

log_table

log_table(data, artifact_path)
Source code in src/formed/integrations/mlflow/workflow.py
644
645
646
647
648
649
650
651
652
653
def log_table(
    self,
    data: Union[dict[str, Sequence[Union[str, bool, int, float]]], "PandasDataFrame"],
    artifact_path: str,
) -> None:
    """Log tabular data as an artifact of the bound run."""
    artifact_file = self._get_artifact_path(artifact_path)
    self.mlflow_client.log_table(
        run_id=self.run.info.run_id,
        data=data,
        artifact_file=artifact_file,
    )

log_text

log_text(text, artifact_path)
Source code in src/formed/integrations/mlflow/workflow.py
655
656
657
658
659
660
661
662
663
664
def log_text(
    self,
    text: str,
    artifact_path: str,
) -> None:
    """Log a text string as an artifact of the bound run."""
    artifact_file = self._get_artifact_path(artifact_path)
    self.mlflow_client.log_text(
        run_id=self.run.info.run_id,
        text=text,
        artifact_file=artifact_file,
    )

log_dict

log_dict(dictionary, artifact_path)
Source code in src/formed/integrations/mlflow/workflow.py
666
667
668
669
670
671
672
673
674
675
def log_dict(
    self,
    dictionary: dict[str, JsonValue],
    artifact_path: str,
) -> None:
    """Log a JSON-serializable dictionary as an artifact of the bound run."""
    artifact_file = self._get_artifact_path(artifact_path)
    self.mlflow_client.log_dict(
        run_id=self.run.info.run_id,
        dictionary=dictionary,
        artifact_file=artifact_file,
    )

log_figure

log_figure(figure, artifact_path)
Source code in src/formed/integrations/mlflow/workflow.py
677
678
679
680
681
682
683
684
685
686
def log_figure(
    self,
    figure: Union["MatplotlibFigure", "PlotlyFigure"],
    artifact_path: str,
) -> None:
    """Log a matplotlib or plotly figure as an artifact of the bound run."""
    artifact_file = self._get_artifact_path(artifact_path)
    self.mlflow_client.log_figure(
        run_id=self.run.info.run_id,
        figure=figure,
        artifact_file=artifact_file,
    )

log_image

log_image(image, artifact_path=None)
Source code in src/formed/integrations/mlflow/workflow.py
688
689
690
691
692
693
694
695
696
697
def log_image(
    self,
    image: Union["NumpyArray", "PILImage", "MlflowImage"],
    artifact_path: Optional[str] = None,
) -> None:
    """Log an image as an artifact of the bound run."""
    artifact_file = self._get_artifact_path(artifact_path)
    self.mlflow_client.log_image(
        run_id=self.run.info.run_id,
        image=image,
        artifact_file=artifact_file,
    )

log_artifact

log_artifact(local_path, artifact_path=None)
Source code in src/formed/integrations/mlflow/workflow.py
699
700
701
702
703
704
705
706
707
708
def log_artifact(
    self,
    local_path: Union[str, PathLike],
    artifact_path: Optional[str] = None,
) -> None:
    """Log a local file as an artifact of the bound run.

    Args:
        local_path: Path of the file to upload.
        artifact_path: Optional sub-directory within the run's artifact root.
    """
    # Coerce PathLike to str for consistency with `log_artifacts`, which
    # already does so before delegating to the MLflow client.
    self.mlflow_client.log_artifact(
        run_id=self.run.info.run_id,
        local_path=str(local_path),
        artifact_path=self._get_artifact_path(artifact_path),
    )

log_artifacts

log_artifacts(local_dir, artifact_path=None)
Source code in src/formed/integrations/mlflow/workflow.py
710
711
712
713
714
715
716
717
718
719
def log_artifacts(
    self,
    local_dir: Union[str, PathLike],
    artifact_path: Optional[str] = None,
) -> None:
    """Log every file under `local_dir` as artifacts of the bound run."""
    directory = str(local_dir)
    self.mlflow_client.log_artifacts(
        run_id=self.run.info.run_id,
        local_dir=directory,
        artifact_path=self._get_artifact_path(artifact_path),
    )

use_mlflow_experiment

use_mlflow_experiment()
Source code in src/formed/integrations/mlflow/workflow.py
722
723
def use_mlflow_experiment() -> Optional[MlflowExperiment]:
    """Return the MLflow experiment bound to the current context, if any."""
    experiment = _MLFLOW_EXPERIMENT.get()
    return experiment

use_mlflow_logger

use_mlflow_logger()
Source code in src/formed/integrations/mlflow/workflow.py
726
727
728
729
730
731
732
733
734
735
736
737
def use_mlflow_logger() -> Optional[MlflowLogger]:
    """Build a `MlflowLogger` for the current step's run, or None when the
    experiment, step context, or MLflow run is unavailable."""
    experiment = use_mlflow_experiment()
    if experiment is None:
        return None

    context = use_step_context()
    if context is None:
        return None

    run = fetch_mlflow_run(MlflowClient(), experiment, step_info=context.info)
    if run is None:
        return None

    return MlflowLogger(run)