69 lines
2.0 KiB
Python
Raw Normal View History

2023-12-06 17:26:39 +08:00
from abc import ABC
2023-12-08 11:25:26 +08:00
from typing import Union, Type, List
2023-12-06 17:26:39 +08:00
2023-12-08 11:25:26 +08:00
import networkx as nx
2023-12-06 17:26:39 +08:00
2023-12-08 11:25:26 +08:00
from knext.common.restable import RESTable
from knext.common.runnable import Runnable
2023-12-06 17:26:39 +08:00
2023-12-08 11:25:26 +08:00
class Chain(Runnable, RESTable):
    """A pipeline of components held as a directed graph.

    Chains are composed with the ``>>`` operator (see ``__rshift__``), which
    returns a new ``Chain`` whose ``dag`` links this chain's end nodes to the
    right-hand operand.
    """

    # Graph of pipeline nodes. ``__rshift__`` expects every node to expose a
    # ``last`` attribute. Presumably declared as a model field by the
    # ``Runnable`` base, since ``Chain(dag=...)`` is used below -- TODO confirm.
    dag: nx.DiGraph

    def submit(self):
        # Placeholder: currently does nothing and returns None.
        pass

    def to_rest(self):
        # Placeholder: currently does nothing and returns None.
        pass

    def __rshift__(
        self,
        other: Union[
            "Chain",
            List["Chain"],
            "Component",
            List["Component"],
            None,
        ],
    ):
        """Return a new chain with ``other`` attached after this chain.

        The *end nodes* of this chain are the nodes with no outgoing edges,
        plus any node whose ``last`` attribute is truthy.

        * A ``Component`` is added as a node, with an edge from every end node.
        * A ``Chain`` is merged in via ``nx.compose``, with an edge from every
          end node to every start node (in-degree 0) of that chain.
        * A list/tuple attaches each item to the same end nodes (fan-out). The
          per-item graphs are then merged with ``nx.compose_all``.

        ``None`` items are skipped. Graphs are copied or composed into new
        graphs, so neither ``self.dag`` nor the operands' graphs are modified.

        Args:
            other: A ``Chain``, a ``Component``, a list/tuple of them, or
                ``None``.

        Returns:
            A new ``Chain``. Returns ``self`` unchanged when ``other`` is
            falsy or contains only ``None`` items.

        Raises:
            TypeError: If an item is neither a ``Chain`` nor a ``Component``.
        """
        # Imported here rather than at module level, presumably to avoid a
        # circular import with knext.component -- TODO confirm.
        from knext.component.base import Component

        if not other:
            return self
        if not isinstance(other, (list, tuple)):
            other = [other]

        # The right-hand side is attached after these nodes. This does not
        # depend on the item being attached, so compute it once.
        end_nodes = [
            node
            for node, out_degree in self.dag.out_degree()
            if out_degree == 0 or node.last
        ]

        dag_list = []
        for o in other:
            if o is None:
                # Previously this dereferenced ``o.dag`` on a falsy item.
                continue
            if isinstance(o, Component):
                # Copy so that self.dag is left untouched.
                dag = nx.DiGraph(self.dag)
                dag.add_node(o)
                dag.add_edges_from((end_node, o) for end_node in end_nodes)
                dag_list.append(dag)
            elif isinstance(o, Chain):
                start_nodes = [
                    node for node, in_degree in o.dag.in_degree() if in_degree == 0
                ]
                combined_dag = nx.compose(self.dag, o.dag)
                combined_dag.add_edges_from(
                    (end_node, start_node)
                    for end_node in end_nodes
                    for start_node in start_nodes
                )
                # Previously the combined graph was built but never collected.
                dag_list.append(combined_dag)
            else:
                raise TypeError(
                    f"unsupported operand for >>: {type(o).__name__!r}; "
                    "expected Chain or Component"
                )

        if not dag_list:
            # nx.compose_all cannot merge an empty list of graphs.
            return self
        final_dag = nx.compose_all(dag_list)
        return Chain(dag=final_dag)