110 lines
2.9 KiB
Python
Raw Normal View History

2023-12-06 17:26:39 +08:00
# -*- coding: utf-8 -*-
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
2023-12-08 11:25:26 +08:00
from abc import ABC
2023-12-06 17:26:39 +08:00
from enum import Enum
2023-12-11 10:44:37 +08:00
from typing import List, Union, Type
2023-12-06 17:26:39 +08:00
2023-12-08 11:25:26 +08:00
import networkx as nx
2023-12-06 17:26:39 +08:00
2023-12-08 11:25:26 +08:00
from knext.common.restable import RESTable
from knext.common.runnable import Runnable
2023-12-06 17:26:39 +08:00
class ComponentTypeEnum(str, Enum):
    """Kinds of component; each member is also a ``str`` equal to its value."""

    # Knowledge-building component.
    Builder = "BUILDER"
    # Reasoning component.
    Reasoner = "REASONER"
    # Graph-learning component.
    GraphLearning = "GRAPH_LEARNING"
2023-12-06 17:26:39 +08:00
2023-12-08 11:25:26 +08:00
class Component(Runnable, RESTable, ABC):
    """
    Base class for all components.

    Components are identity-hashed so each instance is a distinct node when
    wired into a ``networkx.DiGraph`` with the ``>>`` operator; the resulting
    graph is wrapped in a ``Chain``.
    """

    @property
    def id(self):
        """Return this instance's identifier: the string form of builtin ``id``."""
        return str(id(self))

    @property
    def type(self):
        """Return the component type; ``None`` in this base class.

        NOTE(review): presumably overridden by subclasses (e.g. with a
        ``ComponentTypeEnum`` value) — confirm.
        """
        return

    @property
    def label(self):
        """Return the component label; ``None`` in this base class.

        NOTE(review): presumably overridden by subclasses — confirm.
        """
        return

    @property
    def name(self):
        """Return the component name, i.e. the concrete class name."""
        return self.__class__.__name__

    def to_dict(self):
        """Return a minimal dict with this component's ``id`` and ``name``."""
        return {"id": self.id, "name": self.name}

    def __hash__(self):
        # Identity-based hash: every instance is a distinct graph node.
        return id(self)

    def __eq__(self, other):
        """Compare two components by hash; other types are not comparable.

        Returning ``NotImplemented`` for non-components avoids a ``TypeError``
        when ``other`` is unhashable and avoids a spurious match against an
        ``int`` that happens to equal ``id(self)``.
        """
        if not isinstance(other, Component):
            return NotImplemented
        return hash(self) == hash(other)

    def __rshift__(
        self,
        other: Union[
            Type["Chain"],
            List[Type["Chain"]],
            Type["Component"],
            List[Type["Component"]],
            None,
        ],
    ):
        """Connect this component to ``other`` and return the resulting ``Chain``.

        Args:
            other: a ``Component``, a ``Chain``, ``None``, or a list of these.
                (The ``Type[...]`` hints are historical; the code checks
                instances, not classes.)

        Returns:
            ``self`` when ``other`` is falsy. Otherwise a ``Chain`` whose DAG
            composes one sub-graph per item of ``other``:

            * falsy item -> a graph holding only ``self``; ``self.last`` is
              also set to ``True``;
            * ``Component`` item -> the edge ``self -> item``;
            * ``Chain`` item -> the chain's graph plus an edge from ``self``
              to every node of that graph with in-degree 0.

        Raises:
            TypeError: if an item is truthy but neither a ``Component`` nor a
                ``Chain``.
        """
        # Local import, presumably to avoid a circular import with
        # knext.chain.base.
        from knext.chain.base import Chain

        if not other:
            return self
        if not isinstance(other, list):
            other = [other]

        dag_list = []
        for o in other:
            if not o:
                dag = nx.DiGraph()
                self.last = True
                dag.add_node(self)
                dag_list.append(dag)
            elif isinstance(o, Component):
                dag = nx.DiGraph()
                dag.add_node(self)
                dag.add_node(o)
                dag.add_edge(self, o)
                dag_list.append(dag)
            elif isinstance(o, Chain):
                dag = nx.DiGraph()
                dag.add_node(self)
                # Sinks of the left-hand graph (or nodes flagged as last) are
                # where the chain gets attached.
                end_nodes = [
                    node
                    for node, out_degree in dag.out_degree()
                    if out_degree == 0 or getattr(node, "last", False)
                ]
                # Sources of the chain's graph receive the incoming edges.
                start_nodes = [
                    node for node, in_degree in o.dag.in_degree() if in_degree == 0
                ]
                # Compose first so the connecting edges are added to the graph
                # that is actually returned.
                combined_dag = nx.compose(dag, o.dag)
                for end_node in end_nodes:
                    for start_node in start_nodes:
                        combined_dag.add_edge(end_node, start_node)
                dag_list.append(combined_dag)
            else:
                raise TypeError(
                    f"unsupported operand for >>: {type(o).__name__!r}; "
                    "expected Component, Chain, None, or a list of these"
                )

        final_dag = nx.compose_all(dag_list)
        return Chain(dag=final_dag)