# -*- coding: utf-8 -*-
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

import os
from enum import Enum
from typing import Dict, Optional

from knext import rest
from knext.client.base import Client
from knext.common.class_register import register_from_package
from knext.operator.base import BaseOp


class OperatorTypeEnum(str, Enum):
    """Operator categories, keyed by operator base-class name.

    Each member's value is the upper-case type string associated with that
    kind of operator. Because the enum also derives from ``str``, members
    compare equal to their plain string value.
    """

    LinkOp = "LINK"
    FuseOp = "FUSE"
    PredictOp = "PREDICT"
    ExtractOp = "EXTRACT"
    PromptOp = "PROMPT"


class OperatorClient(Client):
    """SPG Operator Client.

    Publishes locally defined operators through the operator REST API and
    turns already-published operators into ``rest.OperatorConfig`` objects.
    """

    # REST API stub, created once at class-definition time and shared by all
    # instances of this client.
    _rest_client = rest.OperatorApi()

    def __init__(
        self, host_addr: Optional[str] = None, project_id: Optional[int] = None
    ):
        """Initialise the client and register local operators if possible.

        Args:
            host_addr: Forwarded unchanged to ``Client.__init__``.
            project_id: Forwarded unchanged to ``Client.__init__``.

        Operator registration only happens when ``BaseOp._has_registered`` is
        falsy and both the ``KNEXT_ROOT_PATH`` and
        ``KNEXT_BUILDER_OPERATOR_DIR`` environment variables are set. In that
        case ``self._builder_operator_path`` is set to their joined path and
        handed to ``register_from_package``; otherwise the attribute is not
        assigned by this constructor.
        """
        super().__init__(host_addr, project_id)
        if not BaseOp._has_registered and (
            "KNEXT_ROOT_PATH" in os.environ
            and "KNEXT_BUILDER_OPERATOR_DIR" in os.environ
        ):
            self._builder_operator_path = os.path.join(
                os.environ["KNEXT_ROOT_PATH"], os.environ["KNEXT_BUILDER_OPERATOR_DIR"]
            )
            register_from_package(self._builder_operator_path, BaseOp)

    def publish(self, op_name: str):
        """Upload operator files and publish a new version.

        If the operator has no overview on the server yet, one is created
        first. When the operator declares ``bind_to``, the operator is bound to
        that SPG type and the schema change is committed.

        Args:
            op_name: Registered name of the operator; it is resolved through
                ``BaseOp.by_name`` and instantiated without arguments.

        Returns:
            The operator instance, with ``_version`` set to the
            ``latest_version`` reported by the version-upload response.
        """
        op = BaseOp.by_name(op_name)()

        operator_list = self._rest_client.operator_overview_get(name=op.name)
        if not operator_list:
            # Not published before: create the overview, then read it back to
            # learn the id assigned by the server.
            self._rest_client.operator_overview_post(
                operator_create_request=rest.OperatorCreateRequest(
                    name=op.name, desc=op.desc, operator_type=op._type
                )
            )
            operator_list = self._rest_client.operator_overview_get(name=op.name)
        operator_id = operator_list[0].id

        add_response = self._rest_client.operator_version_post(
            project_id=self._project_id, operator_id=operator_id, file=op._local_path
        )
        op._version = add_response.latest_version

        if op.bind_to is not None:
            # Imported lazily, as in the original code.
            # NOTE(review): presumably to avoid an import cycle - confirm.
            from knext.client.schema import SchemaClient
            from knext.client.model.base import SpgTypeEnum

            schema_session = SchemaClient().create_session()
            spg_type = schema_session.get(op.bind_to)
            if spg_type.spg_type_enum in [SpgTypeEnum.Entity, SpgTypeEnum.Event]:
                spg_type.bind_link_operator(op)
            elif spg_type.spg_type_enum == SpgTypeEnum.Concept:
                spg_type.bind_normalize_operator(op)
            # Any other type kind gets no binding, but the type is still
            # updated and the session committed, exactly as before.
            schema_session.update_type(spg_type)
            schema_session.commit()
        return op

    def _generate_op_config(
        self,
        op_name: str,
        version: Optional[int] = None,
        params: Optional[Dict[str, str]] = None,
    ):
        """Build a REST ``OperatorConfig`` for a published operator.

        Args:
            op_name: Name of the operator to look up on the server.
            version: Version to select. A falsy value (``None`` or ``0``)
                selects the first entry of the server's version list, which
                this code treats as the latest version.
            params: Passed through unchanged as the config's ``params``.

        Returns:
            A ``rest.OperatorConfig`` describing the selected operator version.

        Raises:
            ValueError: If the operator has no overview, has no versions, or
                the requested ``version`` is not among the published ones.
        """
        overviews = self._rest_client.operator_overview_get(op_name)
        if not overviews:
            raise ValueError(
                f"Operator [{op_name}] is not published."
                f" Use ` knext operator publish {op_name}` to publish this operator."
            )

        operator_versions = self._rest_client.operator_version_get(op_name)
        if not operator_versions:
            raise ValueError(
                f"Operator [{op_name}] is not published."
                f" Use ` knext operator publish {op_name}` to publish this operator."
            )

        if version:
            # Pull operator from server with specified version.
            op = next(
                (
                    candidate
                    for candidate in operator_versions
                    if candidate.version == version
                ),
                None,
            )
            if op is None:
                raise ValueError(
                    f"Operator [{op_name}] with Version [{version}] is not published."
                    f" Use ` knext operator publish {op_name} ` to publish this operator."
                )
        else:
            # Pull operator from server with the latest version. The list
            # fetched above is reused rather than queried a second time, which
            # saves a round-trip and cannot hit an empty second response.
            op = operator_versions[0]

        return rest.OperatorConfig(
            file_path=op.file_path,
            # NOTE(review): `op` is the version object returned by the REST
            # client, so `__module__` is the module of that object's class -
            # confirm this is the intended module path.
            module_path=op.__module__,
            class_name=op.name,
            method="_handle",
            params=params,
        )