mirror of
https://github.com/OpenSPG/openspg.git
synced 2025-07-27 19:11:36 +00:00
127 lines
4.6 KiB
Python
127 lines
4.6 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Copyright 2023 Ant Group CO., Ltd.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
|
# in compliance with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
|
# or implied.
|
|
|
|
import os
|
|
from enum import Enum
|
|
from typing import Dict
|
|
|
|
from knext import rest
|
|
from knext.client.base import Client
|
|
from knext.common.class_register import register_from_package
|
|
from knext.operator.base import BaseOp
|
|
|
|
|
|
class OperatorTypeEnum(str, Enum):
    """Operator type identifiers.

    Each member's value is an upper-case string identifier such as ``"LINK"``.
    Because the enum also subclasses ``str``, every member compares equal to
    its string value (``OperatorTypeEnum.LinkOp == "LINK"``).
    """

    LinkOp = "LINK"
    FuseOp = "FUSE"
    PredictOp = "PREDICT"
    ExtractOp = "EXTRACT"
    PromptOp = "PROMPT"
|
|
|
|
|
|
class OperatorClient(Client):
    """SPG Operator Client.

    Publishes locally defined operators (``BaseOp`` subclasses) to the server
    and builds the REST ``OperatorConfig`` that references a published
    operator version.
    """

    # One REST stub for the operator endpoints, created when the class is
    # defined and shared by every instance.
    _rest_client = rest.OperatorApi()

    def __init__(self, host_addr: str = None, project_id: int = None):
        """Create the client and register local builder operators if needed.

        Args:
            host_addr: Server address, forwarded to ``Client.__init__``.
            project_id: Project id, forwarded to ``Client.__init__``.

        Registration only happens when ``BaseOp._has_registered`` is falsy and
        both ``KNEXT_ROOT_PATH`` and ``KNEXT_BUILDER_OPERATOR_DIR`` are set in
        the environment. In that case ``BaseOp`` subclasses under
        ``$KNEXT_ROOT_PATH/$KNEXT_BUILDER_OPERATOR_DIR`` are passed to
        ``register_from_package``. Presumably this is what makes them
        resolvable through ``BaseOp.by_name`` (used by ``publish``); verify.
        """
        super().__init__(host_addr, project_id)
        if not BaseOp._has_registered and (
            "KNEXT_ROOT_PATH" in os.environ
            and "KNEXT_BUILDER_OPERATOR_DIR" in os.environ
        ):
            self._builder_operator_path = os.path.join(
                os.environ["KNEXT_ROOT_PATH"], os.environ["KNEXT_BUILDER_OPERATOR_DIR"]
            )
            register_from_package(self._builder_operator_path, BaseOp)

    def publish(self, op_name: str):
        """Upload operator files and publish a new version.

        If the operator has not been published yet, an operator overview is
        created first. If the operator declares ``bind_to``, it is then bound
        to that SPG type in the schema.

        Args:
            op_name: Name the operator class is registered under in ``BaseOp``.

        Returns:
            The instantiated operator. Its ``_version`` is set to the
            ``latest_version`` reported by the server after the upload.
        """
        op = BaseOp.by_name(op_name)()
        operator_id = self._get_or_create_operator_id(op)

        add_response = self._rest_client.operator_version_post(
            project_id=self._project_id, operator_id=operator_id, file=op._local_path
        )
        op._version = add_response.latest_version

        if op.bind_to is not None:
            self._bind_to_spg_type(op)

        return op

    def _get_or_create_operator_id(self, op):
        """Return the server id of ``op``'s overview, creating it if it is missing."""
        operator_list = self._rest_client.operator_overview_get(name=op.name)
        if operator_list:
            return operator_list[0].id

        self._rest_client.operator_overview_post(
            operator_create_request=rest.OperatorCreateRequest(
                name=op.name, desc=op.desc, operator_type=op._type
            )
        )
        # The create response is not used for the id. Re-query the overview to
        # read the id the server assigned (presumably the POST response does
        # not carry it; verify against the REST API).
        return self._rest_client.operator_overview_get(name=op.name)[0].id

    def _bind_to_spg_type(self, op):
        """Bind ``op`` to the SPG type named by ``op.bind_to`` and commit the schema."""
        # Function-scope imports; presumably to avoid an import cycle with the
        # schema client (TODO confirm).
        from knext.client.schema import SchemaClient
        from knext.client.model.base import SpgTypeEnum

        schema_session = SchemaClient().create_session()
        spg_type = schema_session.get(op.bind_to)
        if spg_type.spg_type_enum in (SpgTypeEnum.Entity, SpgTypeEnum.Event):
            spg_type.bind_link_operator(op)
        elif spg_type.spg_type_enum == SpgTypeEnum.Concept:
            spg_type.bind_normalize_operator(op)
        # NOTE(review): for any other SPG type nothing is bound, yet the type is
        # still written back and the session committed. Confirm this is intended.
        schema_session.update_type(spg_type)
        schema_session.commit()

    def _generate_op_config(
        self, op_name: str, version: int = None, params: Dict[str, str] = None
    ):
        """Build the REST model ``OperatorConfig`` for a published operator version.

        Args:
            op_name: Name of the published operator.
            version: Version to use. When ``None``, the first entry of the
                server's version listing is used and treated as the latest.
            params: Optional parameters passed through to ``OperatorConfig``.

        Returns:
            A ``rest.OperatorConfig`` whose entry method is ``_handle``.

        Raises:
            ValueError: If the operator has no overview or no versions on the
                server, or if the requested ``version`` is not among them.
        """
        not_published_msg = (
            f"Operator [{op_name}] is not published."
            f" Use ` knext operator publish {op_name}` to publish this operator."
        )
        if not self._rest_client.operator_overview_get(name=op_name):
            raise ValueError(not_published_msg)

        operator_versions = self._rest_client.operator_version_get(op_name)
        if not operator_versions:
            raise ValueError(not_published_msg)

        if version is None:
            # Reuse the listing fetched above instead of querying the server a
            # second time; its first entry is taken as the latest version.
            op = operator_versions[0]
        else:
            # Explicit `is None` check so that version 0 is looked up rather
            # than silently falling back to the latest version.
            op = next(
                (ov for ov in operator_versions if ov.version == version), None
            )
            if op is None:
                raise ValueError(
                    f"Operator [{op_name}] with Version [{version}] is not published."
                    f" Use ` knext operator publish {op_name} ` to publish this operator."
                )

        return rest.OperatorConfig(
            file_path=op.file_path,
            # NOTE(review): `op` is a REST version record returned by the server,
            # so `op.__module__` is the Python module of that REST model class,
            # not the operator's own module. Confirm this is the intended value.
            module_path=op.__module__,
            class_name=op.name,
            method="_handle",
            params=params,
        )
|