openspg/python/knext/tests/operators/test_operators.py
Qu 4129a1a586
feat(knext): add knext feature and refactor some code (#58)
Co-authored-by: didicout <julin.jl@antgroup.com>
Co-authored-by: huaidong.xhd <huaidong.xhd@antgroup.com>
Co-authored-by: Qu <qu266141@antgroup.com>
Co-authored-by: baifuyu <fuyu.bfy@antgroup.com>
2024-01-06 15:55:10 +08:00

141 lines
5.1 KiB
Python

# -*- coding: utf-8 -*-
# Copyright 2023 Ant Group CO., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
from knext.operator.spg_record import SPGRecord
from knext.common.class_register import register_from_package
from knext.operator.base import BaseOp
from knext.operator.op import ExtractOp, LinkOp, FuseOp, PredictOp
register_from_package("./operators", BaseOp)
register_from_package("./operators", ExtractOp)
register_from_package("./operators", LinkOp)
register_from_package("./operators", FuseOp)
register_from_package("./operators", PredictOp)
def get_test_extract_data():
record = {
"type": "Company",
"properties": '{"phone": "+86-12345678", "addr": "China"}',
}
return record
def get_test_record():
properties = {"phone": "+86-12345678", "addr": "China", "name": "taobao"}
return SPGRecord("Company", properties)
def test_get_op():
op = BaseOp.by_name("DummyOp")()
assert isinstance(
op, BaseOp
), f"op should be subclass of BaseOp, got {type(BaseOp)}"
def test_extract_op():
record = get_test_extract_data()
op = ExtractOp.by_name("TestExtractOp")()
op_out = op._handle(*(record,))
spg_records = op_out["data"]
assert isinstance(
spg_records, list
), f"output should be list, got {type(spg_records)}"
assert (
len(spg_records) == 1
), f"expected output length should be 1, got {len(spg_records)}"
assert spg_records[0]["spgTypeName"] == record["type"]
properties = json.loads(record["properties"])
for k, v in spg_records[0]["properties"].items():
assert (
properties[k] == v
), f"value of property {k} should be {properties[k]}, got {v}"
def test_link_op():
subject_record = get_test_record()
op = LinkOp.by_name("TestLinkOp")()
name = subject_record.get_property("name")
op_out = op._handle(*(name, subject_record.to_dict()))
records = op_out["data"]
assert (
len(records) == op.num_outputs
), "length of output records should be {op.num_outputs}, got {len(records)}"
records = [SPGRecord.from_dict(x) for x in records]
for i in range(op.num_outputs):
assert records[i].get_property("index") == str(
i + 1
), f'value of property index should be {i+1}, got {records[i].get_property("index")}'
idx_p = "indexed_property"
assert (
records[i].get_property("indexed_property") == f"{name}_{i+1}"
), f"value of property {idx_p} should be {name}_{i+1}, got {records[i].get_property(idx_p)}"
def test_fuse_op():
subject_record = get_test_record()
op = FuseOp.by_name("TestFuseOp")()
name = subject_record.get_property("name")
op_out = op._handle(*([subject_record.to_dict()],))
print(f"op_out = {op_out}")
records = op_out["data"]
records = [SPGRecord.from_dict(x) for x in records]
assert (
len(records) == op.num_outputs
), "length of output records should be {op.num_outputs}, got {len(records)}"
for i in range(op.num_outputs):
assert records[i].get_property("index") == str(
i + 1
), f'value of property index should be {i}, got {records[i].get_property("index")}'
assert (
records[i].get_property("name") == f"{name}{i+1}"
), f'value of property name should be f"{name}{i+1}", got {records[i].get_property("name")}'
for k, v in subject_record.properties.items():
if k == "name":
continue
assert (
records[i].get_property(k) == v
), f"value of property {k} should be {v}, got {records[i].get_property(k)}"
def test_predict_op():
subject_record = get_test_record()
op = PredictOp.by_name("TestPredictOp")()
name = subject_record.get_property("name")
op_out = op._handle(*(subject_record.to_dict(),))
print(f"op_out = {op_out}")
records = op_out["data"]
records = [SPGRecord.from_dict(x) for x in records]
assert (
len(records) == op.num_outputs
), "length of output records should be {op.num_outputs}, got {len(records)}"
for i in range(op.num_outputs):
assert records[i].get_property("index") == str(
i + 1
), f'value of property index should be {i+1}, got {records[i].get_property("index")}'
assert (
records[i].get_property("name") == f"{name}{i+1}"
), f'value of property name should be f"{name}{i+1}", got {records[i].get_property("name")}'
for k, v in subject_record.properties.items():
if k == "name":
continue
assert (
records[i].get_property(k) == v
), f"value of property {k} should be {v}, got {records[i].get_property(k)}"