Skip to content

Workflow

formed.workflow.WorkflowOrganizer

Bases: Registrable

Source code in formed/workflow/organizer.py
class WorkflowOrganizer(Registrable):
    """Runs workflow executions with a shared cache and callback.

    This base implementation keeps no record of executions: ``get`` always
    returns ``None`` and ``remove`` does nothing.
    """

    def __init__(
        self,
        cache: "WorkflowCache",
        callbacks: Optional[Union[WorkflowCallback, Sequence[WorkflowCallback]]],
    ) -> None:
        """Store the cache and combine the callbacks into a single callback.

        Args:
            cache: Cache handed to the executor on every ``run``.
            callbacks: A single callback, a sequence of callbacks, or
                ``None``/empty for no callbacks.
        """
        # A lone callback is shorthand for a one-element list; any falsy
        # value (None or an empty sequence) becomes an empty list.
        if isinstance(callbacks, WorkflowCallback):
            callback_list = [callbacks]
        else:
            callback_list = callbacks or []

        self.cache = cache
        self.callback = MultiWorkflowCallback(callback_list)

    def run(
        self,
        executor: "WorkflowExecutor",
        execution: Union[WorkflowGraph, WorkflowExecutionInfo],
    ) -> WorkflowExecutionContext:
        """Invoke ``executor`` on ``execution`` inside the executor's context.

        Args:
            executor: Executor to enter as a context manager and then call.
            execution: Graph or execution info passed through to the executor.

        Returns:
            Whatever the executor call returns.
        """
        with executor:
            # Return from inside the ``with`` so the executor's ``__exit__``
            # runs after the call has produced its result.
            return executor(execution, cache=self.cache, callback=self.callback)

    def get(self, execution_id: WorkflowExecutionID) -> Optional[WorkflowExecutionContext]:
        """Look up an execution by id; this base class always returns ``None``."""
        return None

    def exists(self, execution_id: WorkflowExecutionID) -> bool:
        """Return ``True`` when ``get`` yields a context for ``execution_id``."""
        found = self.get(execution_id)
        return found is not None

    def remove(self, execution_id: WorkflowExecutionID) -> None:
        """Remove an execution by id; this base class does nothing."""
        return None

formed.workflow.WorkflowExecutor

Bases: Registrable

Source code in formed/workflow/executor.py
class WorkflowExecutor(Registrable):
    """Base class for workflow executors.

    An executor is callable and usable as a context manager. This base
    class leaves ``__call__`` unimplemented and provides a context manager
    that does nothing on entry or exit.
    """

    def __call__(
        self,
        graph_or_exection: Union[WorkflowGraph, WorkflowExecutionInfo],
        *,
        cache: Optional[WorkflowCache] = None,
        callback: Optional[WorkflowCallback] = None,
    ) -> WorkflowExecutionContext:
        """Execute a workflow; this base class always raises.

        Args:
            graph_or_exection: Workflow graph or execution info to run.
            cache: Optional cache, keyword-only.
            callback: Optional callback, keyword-only.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def __enter__(self: T_WorkflowExecutor) -> T_WorkflowExecutor:
        """Enter the executor context and return the executor itself."""
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Leave the executor context; nothing is released and no exception is suppressed."""
        return None

formed.workflow.WorkflowGraph

Bases: FromJsonnet

Source code in formed/workflow/graph.py
class WorkflowGraph(FromJsonnet):
    """Dependency graph of workflow steps, kept in topological order.

    Every step is stored as a ``WorkflowStepInfo`` built from its name, its
    lazy step definition, and the infos of the steps its config refers to.
    Iterating the graph yields each step after all of its dependencies.
    """

    __COLT_BUILDER__ = COLT_BUILDER

    @classmethod
    def _build_step_info(
        cls,
        steps: Mapping[str, Lazy[WorkflowStep]],
    ) -> Mapping[str, WorkflowStepInfo]:
        """Resolve references between step configs and order the steps.

        Args:
            steps: Lazy step definitions keyed by step name.

        Returns:
            Step infos keyed by name, inserted so that every step comes
            after the steps it depends on. Empty when ``steps`` is empty.

        Raises:
            ConfigurationError: If a config refers to a step name that is
                not in ``steps``, or if the references form a cycle.
        """
        if not steps:
            return {}

        # The first step's builder is used to recognise references in all configs.
        builder = next(iter(steps.values()))._builder

        def find_dependencies(obj: Any, path: tuple[str, ...]) -> frozenset[tuple[StrictParamPath, str]]:
            """Collect ``(config path, referenced step name)`` pairs found under ``obj``."""
            refs: set[tuple[StrictParamPath, str]] = set()
            if WorkflowRef.is_ref(builder, obj):
                refs.add((path, str(obj[WORKFLOW_REFKEY])))
            if isinstance(obj, Mapping):
                for key, value in obj.items():
                    refs |= find_dependencies(value, path + (key,))
            if isinstance(obj, (list, tuple)):
                for i, value in enumerate(obj):
                    refs |= find_dependencies(value, path + (str(i),))
            return frozenset(refs)

        dependencies = {name: find_dependencies(lazy_step.config, ()) for name, lazy_step in steps.items()}

        # Reject dangling references up front; otherwise they would surface
        # later as an uninformative KeyError during the sort.
        for name, refs in dependencies.items():
            for path, dep_name in refs:
                if dep_name not in dependencies:
                    location = ".".join(map(str, path))
                    raise ConfigurationError(
                        f"Step '{name}' refers to unknown step '{dep_name}' at '{location}'"
                    )

        # Insertion-ordered dict used as an ordered set: the keys are the
        # steps on the current DFS path, in the order they were entered.
        stack: dict[str, None] = {}
        visited: set[str] = set()
        sorted_step_names: list[str] = []

        def topological_sort(name: str) -> None:
            """Depth-first visit appending ``name`` after all of its dependencies."""
            if name in stack:
                active = list(stack)
                cycle = active[active.index(name):] + [name]
                raise ConfigurationError("Cycle detected in workflow dependencies: " + " -> ".join(cycle))
            if name in visited:
                return
            stack[name] = None
            visited.add(name)
            for _, dep_name in dependencies[name]:
                topological_sort(dep_name)
            del stack[name]
            sorted_step_names.append(name)

        for name in steps:
            topological_sort(name)

        # Dependencies precede their dependants in sorted_step_names, so each
        # lookup in step_name_to_info below hits an already-built entry.
        step_name_to_info: dict[str, WorkflowStepInfo] = {}
        for name in sorted_step_names:
            step = steps[name]
            step_dependencies = frozenset((path, step_name_to_info[dep_name]) for path, dep_name in dependencies[name])
            step_name_to_info[name] = WorkflowStepInfo(name, step, step_dependencies)

        return step_name_to_info

    def __init__(
        self,
        steps: Mapping[str, Lazy[WorkflowStep]],
    ) -> None:
        """Build the graph from lazy step definitions keyed by step name.

        Raises:
            ConfigurationError: If the step references are dangling or cyclic.
        """
        self._step_info = self._build_step_info(steps)

    def __iter__(self) -> Iterator[WorkflowStepInfo]:
        """Iterate over step infos, dependencies before dependants."""
        return iter(self._step_info.values())

    def __getitem__(self, step_name: str) -> WorkflowStepInfo:
        """Return the info for ``step_name``; raises ``KeyError`` if absent."""
        return self._step_info[step_name]

    def get_subgraph(self, step_name: str) -> "WorkflowGraph":
        """Return a new graph holding ``step_name`` and all of its transitive dependencies.

        Args:
            step_name: Name of the step the subgraph is rooted at.

        Raises:
            ValueError: If ``step_name`` is not a step of this graph.
        """
        if step_name not in self._step_info:
            raise ValueError(f"Step {step_name} not found in the graph")

        # Walk the dependency edges once with a worklist and build the
        # resulting graph a single time, instead of constructing an
        # intermediate graph for every dependency. The root is inserted first.
        subgraph_steps: dict[str, Lazy[WorkflowStep]] = {}
        pending = [self._step_info[step_name]]
        while pending:
            info = pending.pop()
            if info.name in subgraph_steps:
                continue
            subgraph_steps[info.name] = info.step
            pending.extend(dep_info for _, dep_info in info.dependencies)
        return WorkflowGraph(subgraph_steps)

    def visualize(
        self,
        *,
        output: TextIO = sys.stdout,
        # The shared default mapping is only ever read, never mutated.
        additional_info: Mapping[str, str] = {},
    ) -> None:
        """Render the dependency graph to ``output``.

        Args:
            output: Text stream the rendering is written to.
            additional_info: Optional text per step name; when present the
                node is labelled ``"<name>: <text>"`` instead of ``"<name>"``.
        """

        def get_node(name: str) -> str:
            """Return the node label for ``name``."""
            if name in additional_info:
                return f"{name}: {additional_info[name]}"
            return name

        dag = DAG(
            {
                get_node(name): {get_node(dep.name) for _, dep in info.dependencies}
                for name, info in self._step_info.items()
            }
        )

        dag.visualize(output=output)

    def to_dict(self) -> dict[str, Any]:
        """Return ``{"steps": {name: config}}`` for every step in the graph."""
        return {"steps": {step_info.name: step_info.step.config for step_info in self}}

    @classmethod
    def from_config(cls, config: WorkflowGraphConfig) -> "WorkflowGraph":
        """Construct a ``WorkflowGraph`` from ``config`` using the class's colt builder."""
        return cls.__COLT_BUILDER__(config, WorkflowGraph)