# Mirror of https://github.com/OpenSPG/KAG.git
# Synced 2025-06-27 03:20:08 +00:00
# 133 lines, 5.1 KiB, Python
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
|
|
import base64
|
|
import unittest
|
|
|
|
from tabulate import tabulate
|
|
from kag.common.vectorizer.vectorizer import Vectorizer
|
|
|
|
|
|
class TestDifferentVectorizers(unittest.TestCase):
    """Side-by-side smoke test for vectorizers built with ``Vectorizer.from_config``.

    Each ``_get_*_vectorizer`` method assembles one configuration dict and
    hands it to ``Vectorizer.from_config``.  ``testVectorize`` then embeds a
    list of texts with every vectorizer and prints, per vectorizer, the dot
    product of each text's vector with the first text's vector.
    """

    # Base64-encoded host name shared by every "url" entry built below.
    _ENCODED_MODEL_HOST = "YWxwcy1jb21tb24ub3NzLWNuLWhhbmd6aG91LXptZi5hbGl5dW5jcy5jb20="

    @staticmethod
    def _decode(encoded):
        """Return the UTF-8 text stored in the base64 string ``encoded``."""
        return base64.b64decode(encoded).decode("utf-8")

    def _build_downloaded_vectorizer(
        self, vectorizer_class, encoded_path, encoded_model, **extra
    ):
        """Create a vectorizer whose config carries a ``path`` and a ``url``.

        Args:
            vectorizer_class: Value stored under the ``"vectorizer"`` key.
            encoded_path: Base64 text decoded into the ``"path"`` entry.
            encoded_model: Base64 text decoded into the part of the ``"url"``
                entry that follows the shared host.
            **extra: Additional config entries, appended after ``"url"``.

        Returns:
            Whatever ``Vectorizer.from_config`` returns for the config.
        """
        host = self._decode(self._ENCODED_MODEL_HOST)
        model = self._decode(encoded_model)
        config = {
            "vectorizer": vectorizer_class,
            "path": self._decode(encoded_path),
            "url": "https://%s/%s" % (host, model),
        }
        config.update(extra)
        return Vectorizer.from_config(config)

    def _get_bge_zh_vectorizer(self):
        """Build the ``LocalVectorizer`` listed as ``bge_zh``."""
        return self._build_downloaded_vectorizer(
            "kag.common.vectorizer.LocalVectorizer",
            "fi8uY2FjaGUvdmVjdG9yaXplci9CQUFJL2JnZS1iYXNlLXpoLXYxLjU=",
            "YWxwcy9odWFpZG9uZy54aGQvRG9jdW1lbnRzL21vZGVscy9CQUFJLWJnZS1iYXNlLXpoLXYxLjUudGFyLmd6",
        )

    def _get_contriever_vectorizer(self):
        """Build the ``ContrieverVectorizer`` (configured with ``normalize=True``)."""
        return self._build_downloaded_vectorizer(
            "kag.common.vectorizer.ContrieverVectorizer",
            "fi8uY2FjaGUvdmVjdG9yaXplci9mYWNlYm9vay9jb250cmlldmVy",
            "YWxwcy9odWFpZG9uZy54aGQvRG9jdW1lbnRzL21vZGVscy9mYWNlYm9vay1jb250cmlldmVyLnRhci5neg==",
            normalize=True,
        )

    def _get_openai_vectorizer(self):
        """Build the ``OpenAIVectorizer`` aimed at the localhost API base below."""
        config = {
            "vectorizer": "kag.common.vectorizer.OpenAIVectorizer",
            "nn_name": "text-embedding-ada-002",
            "openai_api_key": "EMPTY",
            "openai_api_base": "http://127.0.0.1:38080/v1",
        }
        return Vectorizer.from_config(config)

    def _get_bge_en_vectorizer(self):
        """Build the ``LocalVectorizer`` listed as ``bge_en``."""
        return self._build_downloaded_vectorizer(
            "kag.common.vectorizer.LocalVectorizer",
            "fi8uY2FjaGUvdmVjdG9yaXplci9CQUFJL2JnZS1iYXNlLWVuLXYxLjU=",
            "YWxwcy9odWFpZG9uZy54aGQvRG9jdW1lbnRzL21vZGVscy9CQUFJLWJnZS1iYXNlLWVuLXYxLjUudGFyLmd6",
        )

    def _get_bge_m3_vectorizer(self):
        """Build the ``LocalBGEM3Vectorizer`` listed as ``bge_m3``."""
        return self._build_downloaded_vectorizer(
            "kag.common.vectorizer.LocalBGEM3Vectorizer",
            "fi8uY2FjaGUvdmVjdG9yaXplci9CQUFJL2JnZS1tMw==",
            "YWxwcy9odWFpZG9uZy54aGQvRG9jdW1lbnRzL21vZGVscy9CQUFJLWJnZS1tMy50YXIuZ3o=",
        )

    def _get_vectorizers(self):
        """Return ``(name, vectorizer)`` pairs in the order used for table columns."""
        return (
            ("bge_zh", self._get_bge_zh_vectorizer()),
            ("contriever", self._get_contriever_vectorizer()),
            ("openai", self._get_openai_vectorizer()),
            ("bge_en", self._get_bge_en_vectorizer()),
            ("bge_m3", self._get_bge_m3_vectorizer()),
        )

    def setUp(self):
        """Instantiate every vectorizer before each test method runs."""
        self.vectorizers = self._get_vectorizers()

    def tearDown(self):
        """Nothing to clean up."""
        pass

    def _print_similarity_table(self, inputs):
        """Embed ``inputs`` with every vectorizer and print one comparison table.

        Row ``i`` holds, for each vectorizer, the dot product between the
        vector of ``inputs[0]`` and the vector of ``inputs[i]``.
        """
        headers = ("#",) + tuple(name for name, _vectorizer in self.vectorizers)
        columns = []
        for name, vectorizer in self.vectorizers:
            vecs = vectorizer.vectorize(inputs)
            # The rows below pair scores by position, so each vectorizer must
            # return exactly one vector per input text.
            self.assertEqual(
                len(vecs),
                len(inputs),
                "%s returned an unexpected number of vectors" % name,
            )
            anchor = vecs[0]
            columns.append([sum(x * y for x, y in zip(anchor, vec)) for vec in vecs])
        data = [[index, *scores] for index, scores in enumerate(zip(*columns))]
        print(tabulate(data, headers=headers))

    def testVectorize(self):
        """Print similarity tables for an English and a Chinese set of aliases."""
        inputs = [
            "George Washington",
            "Father of the United States",
            "President Washington",
            "The American George",
            "Washington the Great",
        ]
        inputs2 = [
            "诸葛亮",
            "卧龙先生",
            "诸葛丞相",
            "武乡侯",
            "孔明先生",
        ]
        # Both lists are exercised; previously ``inputs2`` was built but never used.
        for texts in (inputs, inputs2):
            self._print_similarity_table(texts)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner,
    # which discovers and runs the TestCase classes defined in this module.
    unittest.main()
|