from abc import ABC
from typing import Union, Type, List

import networkx as nx

from knext.common.restable import RESTable
from knext.common.runnable import Runnable


class Chain(Runnable, RESTable):
    """A pipeline stored as a directed graph of its steps.

    Chains are composed with the ``>>`` operator (see ``__rshift__``), which
    returns a new ``Chain`` whose ``dag`` links this chain's end nodes to the
    right-hand operand. The operands' graphs are not mutated.
    """

    # Graph of the pipeline; nodes are the objects attached via ``>>``
    # (``Component`` instances or the nodes of composed chains).
    dag: nx.DiGraph

    def submit(self):
        # NOTE(review): placeholder; no submission logic is implemented here.
        pass

    def to_rest(self):
        # NOTE(review): placeholder; returns None, no REST conversion yet.
        pass

    def __rshift__(
        self,
        other: Union[
            "Chain",
            "Component",
            List[Union["Chain", "Component", None]],
            None,
        ],
    ) -> "Chain":
        """Append ``other`` after this chain and return the combined chain.

        The attachment points of ``self`` ("end nodes") are nodes with no
        outgoing edges, plus any node whose ``last`` attribute is truthy.

        - ``Component``: the component is added as a node and every end node
          of ``self`` gets an edge to it.
        - ``Chain``: both graphs are composed and every end node of ``self``
          gets an edge to every start node (in-degree 0) of ``other``.
        - ``None`` inside a list: contributes ``self``'s graph unchanged.
        - A list: each element is attached to ``self`` independently (fan-out)
          and all resulting graphs are composed into one.

        Args:
            other: The component/chain (or list of them) to run after ``self``.
                A falsy value (``None`` or an empty list) is a no-op.

        Returns:
            ``self`` when ``other`` is falsy, otherwise a new ``Chain``.

        Raises:
            TypeError: If an operand is not a ``Chain``, ``Component`` or ``None``.
        """
        # Imported lazily, as in the original, presumably to avoid an import cycle.
        from knext.component.base import Component

        if not other:
            return self
        if not isinstance(other, list):
            other = [other]

        # The end nodes of ``self`` are the same for every operand, so compute
        # them once. NOTE(review): assumes every node with outgoing edges
        # exposes a ``last`` attribute.
        end_nodes = [
            node
            for node, out_degree in self.dag.out_degree()
            if out_degree == 0 or node.last
        ]

        dag_list = []
        for o in other:
            if o is None:
                # Nothing to attach: keep (a copy of) this chain's graph.
                dag_list.append(nx.DiGraph(self.dag))
            elif isinstance(o, Component):
                dag = nx.DiGraph(self.dag)
                for end_node in end_nodes:
                    dag.add_edge(end_node, o)
                # Ensures the component is present even when self.dag is
                # empty (no end nodes, hence no edges were added above).
                dag.add_node(o)
                dag_list.append(dag)
            elif isinstance(o, Chain):
                combined_dag = nx.compose(self.dag, o.dag)
                start_nodes = [
                    node for node, in_degree in o.dag.in_degree() if in_degree == 0
                ]
                for end_node in end_nodes:
                    for start_node in start_nodes:
                        combined_dag.add_edge(end_node, start_node)
                # Previously this graph was built but never collected, which
                # silently dropped the right-hand chain.
                dag_list.append(combined_dag)
            else:
                raise TypeError(
                    "Chain >> expects a Chain, a Component, a list of them, "
                    f"or None; got {type(o).__name__}"
                )

        # NOTE(review): relies on the base class accepting ``dag`` as a keyword.
        return Chain(dag=nx.compose_all(dag_list))