KAG/tests/unit/common/registry/test_registry.py
zhuzhongshu123 e1d818dfaa refactor(all): kag v0.6 (#174)
* add path find

* fix find path

* spg guided relation extraction

* fix dict parse with same key

* rename graphalgoclient to graphclient

* rename graphalgoclient to graphclient

* file reader supports http url

* add checkpointer class

* parser supports checkpoint

* add build

* remove incorrect logs

* remove logs

* update examples

* update chain checkpointer

* vectorizer batch size set to 32

* add a zodb-backed checkpointer

* add a zodb-backed checkpointer

* fix zodb based checkpointer

* add thread for zodb IO

* fix(common): resolve multithread conflict in zodb IO

* fix(common): load existing zodb checkpoints

* update examples

* update examples

* fix zodb writer

* add docstring

* fix jieba version mismatch

* commit kag_config-tc.yaml

1. rename type to register_name
2. put a unique & specific name to register_name
3. rename reader to scanner
4. rename parser to reader
5. rename num_parallel to num_parallel_file, rename chain_level_num_paralle to num_parallel_chain_of_file
6. rename kag_extractor to schema_free_extractor, schema_base_extractor to schema_constraint_extractor
7. pre-define llm & vectorize_model and refer to them in the yaml file

Issues to be resolved:
1. examples of event extract & spg extract
2. statistics of the indexer, such as the number of nodes & edges extracted and the ratio of llm invocations
3. exceptions such as "Debt" or "account does not exist" should be thrown in llm invoke
4. the solver conf needs to be re-examined

* commit kag_config-tc.yaml

1. rename type to register_name
2. put a unique & specific name to register_name
3. rename reader to scanner
4. rename parser to reader
5. rename num_parallel to num_parallel_file, rename chain_level_num_paralle to num_parallel_chain_of_file
6. rename kag_extractor to schema_free_extractor, schema_base_extractor to schema_constraint_extractor
7. pre-define llm & vectorize_model and refer to them in the yaml file

Issues to be resolved:
1. examples of event extract & spg extract
2. statistics of the indexer, such as the number of nodes & edges extracted and the ratio of llm invocations
3. exceptions such as "Debt" or "account does not exist" should be thrown in llm invoke
4. the solver conf needs to be re-examined

* fix bug in base_table_splitter

* fix bug in base_table_splitter

* fix bug in default_chain

* add solver

* add kag

* update outline splitter

* add main test

* add op

* code refactor

* add tools

* fix outline splitter

* fix outline prompt

* graph api pass

* commit with page rank

* add search api and graph api

* add markdown report

* fix vectorizer num batch compute

* add retry for vectorize model call

* update markdown reader

* update markdown reader

* update pdf reader

* raise extractor failure

* add default expr

* add log

* merge jc reader features

* rm import

* rm parser

* run pipeline

* add config option for whether to perform llm config check, defaulting to false

* fix

* recover pdf reader

* several components can be null for default chain

* support full qa run

* add if

* remove unused code

* use chunk as fallback

* excluded source relation to choose

* add generate

* default recall 10

* add local memory

* exclude similar edges

* add safeguards

* fix concurrency issue

* add debug logger

* support parameterized topk

* support chunk truncation and adjust the spo select prompt

* add safeguards for query requests

* add force_chunk config

* fix entity linker algorithm

* add sub query rewriting

* fix md reader dup in test

* fix

* merge knext to kag parallel

* fix package

* fix metric drop issue

* scanner update

* scanner update

* add doc and update example scripts

* fix

* add bridge to spg server

* add format

* fix bridge

* update conf for baike

* disable ckpt for spg server runner

* llm invoke errors raise exceptions by default

* chore(version): bump version to X.Y.Z

* update default response generation prompt

* add method getSummarizationMetrics

* fix(common): fix project conf empty error

* fix typo

* add reporting information

* update main solver

* postprocessor supports spg server

* update solver supported names

* fix language

* update chunker interface, add openapi

* rename vectorizer to vectorize_model in spg server config

* generate_random_string starts with gen

* add knext llm vector checker

* add knext llm vector checker

* add knext llm vector checker

* remove default values from solver

* update yaml and register_name for baike

* update yaml and register_name for baike

* remove config key check

* fix llmmodule

* fix knext project

* update yaml and register_name for examples

* update yaml and register_name for examples

* Revert "udpate yaml and register_name for examples"

This reverts commit b3fa5ca9ba749e501133ac67bd8746027ab839d9.

* update register name

* fix

* fix

* support multiple register names

* update component

* update reader register names (#183)

* fix markdown reader

* fix llm client for retry

* feat(common): add processed chunk id checkpoint (#185)

* update reader register names

* add processed chunk id checkpoint

* feat(example): add example config (#186)

* update reader register names

* add processed chunk id checkpoint

* add example config file

* add max_workers parameter for getSummarizationMetrics to make it faster

* add csqa data generation script generate_data.py

* commit generated csqa builder and solver data

* add csqa basic project files

* adjust split_length and num_threads_per_chain to match lightrag settings

* ignore ckpt dirs

* add csqa evaluation script eval.py

* save evaluation scripts summarization_metrics.py and factual_correctness.py

* save LightRAG output csqa_lightrag_answers.json

* ignore KAG output csqa_kag_answers.json

* add README.md for CSQA

* fix(solver): fix solver pipeline conf (#191)

* update reader register names

* add processed chunk id checkpoint

* add example config file

* update solver pipeline config

* fix project create

* update links and file paths

* reformat csqa kag_config.yaml

* reformat csqa python files

* reformat getSummarizationMetrics and compare_summarization_answers

* fix(solver): fix solver config (#192)

* update reader register names

* add processed chunk id checkpoint

* add example config file

* update solver pipeline config

* fix project create

* fix main solver conf

* add except

* fix typo in csqa README.md

* feat(conf): support reinitialize config for call from java side (#199)

* update reader register names

* add processed chunk id checkpoint

* add example config file

* update solver pipeline config

* fix project create

* fix main solver conf

* support reinitialize config for java call

* revert default response generation prompt

* update project list

* add README.md for the hotpotqa, 2wiki and musique examples

* add spo retrieval

* turn off kag config dump by default

* turn off knext schema dump by default

* add .gitignore and fix kag_config.yaml

* add README.md for the medicine example

* add README.md for the supplychain example

* bugfix for risk mining

* use exact out

* refactor(solver): format solver code (#205)

* update reader register names

* add processed chunk id checkpoint

* add example config file

* update solver pipeline config

* fix project create

* fix main solver conf

* support reinitialize config for java call

* black format

---------

Co-authored-by: peilong <peilong.zpl@antgroup.com>
Co-authored-by: 锦呈 <zhangxinhong.zxh@antgroup.com>
Co-authored-by: zhengke.gzk <zhengke.gzk@antgroup.com>
Co-authored-by: huaidong.xhd <huaidong.xhd@antgroup.com>
2025-01-03 17:10:51 +08:00


# -*- coding: utf-8 -*-
import json
from typing import List, Dict, Union
from pyhocon import ConfigTree, ConfigFactory
from kag.common.registry import Registrable, Lazy, Functor
import numpy as np
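# Smoke test: list all registered LLMClient implementations together with their details.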
def test_list_available():
from kag.interface import LLMClient
ava = LLMClient.list_available_with_detail()
print(json.dumps(ava, indent=4))
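# A small Registrable hierarchy used throughout the tests: subclasses register under string names and are built via from_config.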
class MockModel(Registrable):
def __init__(self, name: str = "mock_model"):
self.name = name
@MockModel.register("Simple")
class Simple(MockModel):
def __init__(self, name, age=None):
pass
@MockModel.register("gaussian")
class Gaussian(MockModel):
def __init__(
self, mean: float, variance: float, noise: List[int], attr: ConfigTree
):
pass
@MockModel.register("gaussian_var")
class GaussianVar(MockModel):
def __init__(
self,
mean: float,
variance: float,
noise: List[int],
**kwargs,
):
pass
@MockModel.register("obj_gaussian")
class ObjGaussian(MockModel):
def __init__(self, mean: float, variance: float, data: List[np.ndarray]):
pass
@MockModel.register("complex_gaussian_var")
class ComplexGaussianVar(MockModel):
def __init__(self, list_gaussian: List[GaussianVar], **kwargs):
pass
@MockModel.register("complex_gaussian")
class ComplexGaussian(MockModel):
def __init__(
self, dict_gaussian: Dict[str, Gaussian], list_gaussian: List[Gaussian]
):
pass
@MockModel.register("lazy_gaussian")
class LazyGaussian(MockModel):
def __init__(
self,
gaussian: Lazy[Gaussian],
):
pass
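# A Registrable exposing several named constructors; each registered name maps to the classmethod used for construction.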
class BaseCount(Registrable):
pass
@BaseCount.register("default", as_default=True)
@BaseCount.register("from_list_of_ints", constructor="from_list_of_ints")
@BaseCount.register("from_list_of_strings", constructor="from_list_of_strings")
@BaseCount.register("from_string_length", constructor="from_string_length")
class Count(BaseCount):
def __init__(self, count: int):
self.count = count
@classmethod
def from_list_of_ints(cls, int_list: List[int]):
ins = cls(len(int_list))
        # We must set int_list as an attribute on the instance; otherwise it
        # cannot be correctly converted back to params.
setattr(ins, "int_list", int_list)
return ins
@classmethod
def from_list_of_strings(cls, str_list: List[str]):
ins = cls(len(str_list))
setattr(ins, "str_list", str_list)
return ins
@classmethod
def from_string_length(cls, string: str):
ins = cls(len(string))
setattr(ins, "string", string)
return ins
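# Hierarchies used to exercise Union-typed constructor arguments (see test_union_type below).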
class Type1(Registrable):
pass
class Type2(Registrable):
pass
class MixBase(Registrable):
pass
@Type1.register("sub1")
class Sub1(Type1):
def __init__(self, name1: str):
pass
@Type1.register("sub11")
class Sub11(Type1):
def __init__(self, name11: str):
pass
@Type2.register("sub2")
class Sub2(Type2):
def __init__(self, name2: str):
pass
@MixBase.register("mix1")
class Mix1(MixBase):
def __init__(self, sub: Union[Type1, Type2]):
pass
@MixBase.register("mix2")
class Mix2(MixBase):
def __init__(self, sub: Union[Sub1, Sub11, Sub2]):
pass
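# A multi-level hierarchy: Depth2_1 registers under both Root and Depth1_1, so it can be built from either base (see test_nested).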
class Root(Registrable):
pass
@Root.register("depth1_1")
class Depth1_1(Root):
def __init__(self, depth1_1: str):
pass
@Root.register("depth1_2")
class Depth1_2(Root):
def __init__(self, depth1_2: str):
pass
@Root.register("depth2_1")
@Depth1_1.register("depth2_1")
class Depth2_1(Depth1_1):
def __init__(self, depth2_1: str):
pass
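# Functor registers plain functions; from_config returns a callable with the configured arguments bound (see test_functor).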
@Functor.register("simple")
def simple_func(name: "str", age: list = []):
print(f"name = {name}")
print(f"age = {age}")
return sum(age)
@Functor.register("complex")
def complex_func(gaussian: ComplexGaussian):
return len(gaussian.dict_gaussian), len(gaussian.list_gaussian)
@Functor.register("with_kwargs")
def simple_func3(**kwargs):
print(f"kwargs = {kwargs}")
return kwargs
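# Helper that builds a pyhocon config for a ComplexGaussian: two entries in dict_gaussian and three in list_gaussian.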
def gen_conf():
gaussian_0 = {
"mean": 0,
"variance": 1,
"noise": [2, 3, 4],
"attr": {"name": "xxx", "age": 999},
}
gaussian_1 = {
"mean": 13,
"variance": 2,
"noise": [3, 4, 5],
"attr": {"name": "yyy", "age": 11},
}
gaussian_2 = {
"mean": 20,
"variance": 3,
"noise": [4, 5, 6],
"attr": {"name": "zzz", "age": 234},
}
gaussian_3 = {
"mean": 39,
"variance": 3,
"noise": [4, 5, 6],
"attr": {"name": "xxx", "age": 66},
}
gaussian_4 = {
"mean": 47,
"variance": 3,
"noise": [4, 5, 6],
"attr": {"name": "xxx", "age": 712},
}
params_dict = {
"dict_gaussian": {"0": gaussian_0, "1": gaussian_1},
"list_gaussian": [gaussian_2, gaussian_3, gaussian_4],
}
params = ConfigFactory.from_dict(params_dict)
return params
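# Nested configs are recursively resolved into the Gaussian instances declared by the annotations.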
def test_from_param():
params = gen_conf()
model = ComplexGaussian.from_config(params)
assert model.list_gaussian[-1].mean == 47
def test_from_param_base():
params = gen_conf()
params.put("type", "complex_gaussian")
model = MockModel.from_config(params)
assert (
type(model) is ComplexGaussian
), f"expect type ComplexGaussian, got {type(model)}"
assert model.list_gaussian[-1].mean == 47
def test_to_config():
params = gen_conf()
model = ComplexGaussian.from_config(params)
reconstructed_params = model.to_config()
reconstructed_model = ComplexGaussian.from_config(reconstructed_params)
assert len(reconstructed_model.list_gaussian) == 3
assert reconstructed_model.list_gaussian[-1].mean == 47
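# The type key selects which registered constructor builds the instance; to_config / to_config_with_constructor round-trip the original arguments.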
def test_multi_constructor():
    # without a type key, the default implementation is used
params = ConfigFactory.from_dict({"count": 32})
ins = BaseCount.from_config(params)
reconstructed_params = ins.to_config()
assert reconstructed_params.count == 32
params = ConfigFactory.from_dict(
{"type": "from_list_of_ints", "int_list": [1, 2, 3]}
)
ins = BaseCount.from_config(params)
reconstructed_params = ins.to_config()
assert reconstructed_params.type == "from_list_of_ints"
assert reconstructed_params.int_list == [1, 2, 3]
params = ConfigFactory.from_dict(
{"type": "from_list_of_strings", "str_list": ["1", "2", "#", "*"]}
)
ins = BaseCount.from_config(params)
reconstructed_params = ins.to_config_with_constructor("from_list_of_strings")
assert reconstructed_params.type == "from_list_of_strings"
assert reconstructed_params.str_list == ["1", "2", "#", "*"]
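# For Union-typed arguments, the nested type key (when present) picks the concrete subclass.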
def test_union_type():
params = ConfigFactory.from_dict(
{"type": "mix1", "sub": {"type": "sub11", "name11": "sub11"}}
)
ins = MixBase.from_config(params)
assert type(ins.sub) == Sub11
assert ins.sub.name11 == "sub11"
params = ConfigFactory.from_dict(
{"type": "mix1", "sub": {"type": "sub2", "name2": "sub2"}}
)
ins = MixBase.from_config(params)
assert type(ins.sub) == Sub2
assert ins.sub.name2 == "sub2"
    # for Mix2, the type key of sub is not required, since the concrete class is indicated by the __init__ annotation
params = ConfigFactory.from_dict({"type": "mix2", "sub": {"name2": "sub2"}})
ins = MixBase.from_config(params)
assert type(ins.sub) == Sub2
assert ins.sub.name2 == "sub2"
def test_nested():
conf = ConfigFactory.from_dict({"type": "depth1_1", "depth1_1": "zz"})
ins = Root.from_config(conf)
assert type(ins) == Depth1_1
# instantiate from intermediate class (has both parent class and subclass)
conf = ConfigFactory.from_dict({"depth1_1": "zz"})
ins = Depth1_1.from_config(conf)
assert type(ins) == Depth1_1
    # instantiate from a leaf class (has no subclass)
conf = ConfigFactory.from_dict({"type": "depth1_2", "depth1_2": "zz"})
ins = Depth1_2.from_config(conf)
assert type(ins) == Depth1_2
    # instantiate from the root class (requires an extra register declaration)
conf = ConfigFactory.from_dict({"type": "depth2_1", "depth2_1": "zz"})
ins = Root.from_config(conf)
assert type(ins) == Depth2_1
# instantiate from parent class
conf = ConfigFactory.from_dict({"type": "depth2_1", "depth2_1": "zz"})
ins = Depth1_1.from_config(conf)
assert type(ins) == Depth2_1
def test_pass_dict():
conf = {"type": "depth1_1", "depth1_1": "zz"}
ins = Root.from_config(conf)
assert type(ins) == Depth1_1
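# Keys not declared in __init__ are collected via **kwargs and become attributes on the instance.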
def test_with_kwargs():
conf = {
"type": "gaussian_var",
"mean": 1.1,
"variance": "2.2",
"noise": [1, 2, 3],
"less": "more",
"x": "y",
}
res = GaussianVar.from_config(conf)
assert res.less == "more" and res.x == "y"
conf = {
"type": "complex_gaussian_var",
"less": "more",
"x": "y",
"list_gaussian": [
{"mean": 0.7, "variance": 1.1, "noise": [1, 2, 3], "less": "more", "x": "y"}
],
}
ComplexGaussianVar.list_available()
res = ComplexGaussianVar.from_config(conf)
assert (
res.less == "more"
and res.x == "y"
and res.list_gaussian[0].less == "more"
and res.list_gaussian[0].x == "y"
)
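# Already-instantiated objects in a config are passed through as-is instead of being rebuilt from params.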
def test_with_obj():
conf = {
"type": "gaussian_var",
"mean": 1.1,
"variance": "2.2",
"noise": [1, 2, 3],
"less": "more",
"x": "y",
}
res = GaussianVar.from_config(conf)
assert res.less == "more" and res.x == "y"
conf = {
"type": "complex_gaussian_var",
"less": "more",
"x": "y",
"list_gaussian": [
{"mean": 0.7, "variance": 1.1, "noise": [1, 2, 3], "less": "more", "x": "y"}
],
}
res = ComplexGaussianVar.from_config(conf)
# use object instead of config
conf["list_gaussian"] = res.list_gaussian
res2 = ComplexGaussianVar.from_config(conf)
assert id(res.list_gaussian[0]) == id(
res2.list_gaussian[0]
), "The two objects are different!!"
data = np.random.rand(128)
conf = {"mean": 1.1, "variance": 2.2, "data": data}
res = ObjGaussian.from_config(conf)
assert data is res.data, "The two objects are different!!"
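# Functor configs round-trip: to_config captures the bound arguments and from_config rebuilds an equivalent callable.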
def test_functor():
simple_conf = ConfigFactory.from_dict(
{"type": "simple", "name": "pyfunc", "age": [1, 2, 3]}
)
func = Functor.from_config(simple_conf)
reconstructed_conf = func.to_config()
reconstructed_func = Functor.from_config(reconstructed_conf)
assert reconstructed_func() == 6
complex_conf = ConfigFactory.from_dict({"type": "complex", "gaussian": gen_conf()})
func = Functor.from_config(complex_conf)
reconstructed_conf = func.to_config()
reconstructed_func = Functor.from_config(reconstructed_conf)
assert reconstructed_func() == (2, 3)
with_kwargs_conf = ConfigFactory.from_dict(
{"type": "with_kwargs", "name": "pyfunc"}
)
func = Functor.from_config(with_kwargs_conf)
kwargs = func()
assert kwargs["name"] == "pyfunc"