2024-10-24 11:46:15 +08:00
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
# Copyright 2023 OpenSPG Authors
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
|
|
|
# in compliance with the License. You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
|
|
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
|
|
|
# or implied.
|
|
|
|
import os
|
|
|
|
from collections import defaultdict
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
from kag.builder.model.sub_graph import SubGraph
|
|
|
|
from knext.common.base.runnable import Input, Output
|
2024-10-24 17:28:14 +08:00
|
|
|
from kag.common.vectorizer import Vectorizer
|
2024-10-24 11:46:15 +08:00
|
|
|
from kag.interface.builder.vectorizer_abc import VectorizerABC
|
|
|
|
from knext.schema.client import SchemaClient
|
|
|
|
from knext.project.client import ProjectClient
|
|
|
|
from knext.schema.model.base import IndexTypeEnum
|
|
|
|
|
|
|
|
|
2024-10-24 17:28:14 +08:00
|
|
|
class EmbeddingVectorPlaceholder(object):
    """Stand-in for an embedding vector that has not been assigned yet.

    The placeholder remembers the property dict it belongs to and the field
    name the vector should be stored under. Once ``_embedding_vector`` has
    been assigned, ``replace`` writes it into that dict.
    """

    def __init__(self, number, properties, vector_field, property_key, property_value):
        """Record where the vector belongs and which text it is derived from.

        Args:
            number: Identifier of this placeholder; it is what ``repr`` shows.
            properties: Dict that receives the vector when ``replace`` runs.
            vector_field: Key in ``properties`` under which the vector is stored.
            property_key: Name of the property the text was taken from.
            property_value: The text to be embedded.
        """
        self._number, self._properties = number, properties
        self._vector_field = vector_field
        self._property_key, self._property_value = property_key, property_value
        # Remains None until a vector is assigned from outside.
        self._embedding_vector = None

    def replace(self):
        """Store the vector in the property dict; do nothing if none is set."""
        vector = self._embedding_vector
        if vector is None:
            return
        self._properties[self._vector_field] = vector

    def __repr__(self):
        """Render as the repr of the placeholder number only."""
        return f"{self._number!r}"
|
|
|
|
|
|
|
|
|
|
|
|
class EmbeddingVectorManager(object):
    """Collect embedding placeholders and resolve them with one vectorizer call.

    Placeholders are registered through ``get_placeholder``. ``batch_generate``
    groups them by text so that each distinct text is passed to the vectorizer
    only once, and ``patch`` asks every placeholder to write its vector back.
    """

    def __init__(self):
        # Every placeholder handed out so far, in creation order.
        self._placeholders = []

    def _create_vector_field_name(self, property_key):
        """Return the vector field name derived from ``property_key``.

        The result is ``"_"`` followed by the snake-cased form of
        ``f"{property_key}_vector"``.
        """
        from kag.common.utils import to_snake_case

        return "_" + to_snake_case(f"{property_key}_vector")

    def get_placeholder(self, properties, vector_field):
        """Create a placeholder for the property that backs ``vector_field``.

        Args:
            properties: Property dict of a node; it is searched for the key
                whose derived vector field name equals ``vector_field``.
            vector_field: Name of the vector field to be filled.

        Returns:
            A new ``EmbeddingVectorPlaceholder`` registered with this manager,
            or ``None`` when no property maps to ``vector_field`` or the
            matching property value is empty.

        Raises:
            RuntimeError: If the matching property value is not a string.
        """
        for property_key, property_value in properties.items():
            if self._create_vector_field_name(property_key) != vector_field:
                continue
            if not property_value:
                # Nothing to embed for an empty value.
                return None
            if not isinstance(property_value, str):
                message = f"property {property_key!r} must be string to generate embedding vector"
                raise RuntimeError(message)
            placeholder = EmbeddingVectorPlaceholder(
                len(self._placeholders), properties, vector_field, property_key, property_value
            )
            self._placeholders.append(placeholder)
            return placeholder
        return None

    def _get_text_batch(self):
        """Group the registered placeholders by their text.

        Returns:
            Dict mapping each distinct text to the list of placeholders that
            share it, in first-seen order.
        """
        text_batch = {}
        for placeholder in self._placeholders:
            text_batch.setdefault(placeholder._property_value, []).append(placeholder)
        return text_batch

    def _generate_vectors(self, vectorizer, text_batch):
        """Vectorize the distinct texts of ``text_batch``.

        Returns an empty list without calling the vectorizer when there is
        nothing to embed, so no pointless (and possibly failing) request is
        issued for an empty batch.
        """
        texts = list(text_batch)
        if not texts:
            return []
        return vectorizer.vectorize(texts)

    def _fill_vectors(self, vectors, text_batch):
        """Assign each vector to all placeholders sharing the matching text.

        Vectors are paired with texts by position.
        NOTE(review): this assumes ``vectorize`` returns one vector per text in
        input order -- confirm against the vectorizer implementation.
        """
        for vector, placeholders in zip(vectors, text_batch.values()):
            for placeholder in placeholders:
                placeholder._embedding_vector = vector

    def batch_generate(self, vectorizer):
        """Compute vectors for all registered placeholders.

        Args:
            vectorizer: Object exposing ``vectorize(texts)``.
        """
        text_batch = self._get_text_batch()
        vectors = self._generate_vectors(vectorizer, text_batch)
        self._fill_vectors(vectors, text_batch)

    def patch(self):
        """Ask every registered placeholder to write its vector back."""
        for placeholder in self._placeholders:
            placeholder.replace()
|
|
|
|
|
|
|
|
|
|
|
|
class EmbeddingVectorGenerator(object):
    """Fill in the vector fields of a batch of nodes using a vectorizer."""

    def __init__(self, vectorizer, vector_index_meta=None, extra_labels=("Entity",)):
        """Store the collaborators used by ``batch_generate``.

        Args:
            vectorizer: Object handed to ``EmbeddingVectorManager.batch_generate``.
            vector_index_meta: Mapping from label to the vector field names
                of that label; a falsy value is replaced by an empty dict.
            extra_labels: Labels looked up for every node in addition to the
                node's own label.
        """
        self._vectorizer = vectorizer
        self._extra_labels = extra_labels
        self._vector_index_meta = vector_index_meta if vector_index_meta else {}

    def batch_generate(self, node_batch):
        """Add the missing vector fields to every property dict in the batch.

        Args:
            node_batch: Iterable of ``(label, properties)`` pairs. Each
                ``properties`` dict is modified in place.
        """
        index_meta = self._vector_index_meta
        manager = EmbeddingVectorManager()
        for node_label, properties in node_batch:
            candidate_labels = [node_label, *(self._extra_labels or ())]
            # Walk the vector fields of every candidate label that has index
            # metadata, lazily and in label order.
            wanted_fields = (
                field
                for candidate in candidate_labels
                if candidate in index_meta
                for field in index_meta[candidate]
            )
            for vector_field in wanted_fields:
                if vector_field in properties:
                    # Already present (or already given a placeholder).
                    continue
                placeholder = manager.get_placeholder(properties, vector_field)
                if placeholder is not None:
                    properties[vector_field] = placeholder
        manager.batch_generate(self._vectorizer)
        manager.patch()
|
|
|
|
|
|
|
|
|
2024-10-24 11:46:15 +08:00
|
|
|
class BatchVectorizer(VectorizerABC):
    """Builder component that adds embedding vectors to the nodes of a SubGraph.

    On construction it resolves the vectorizer configuration, loads the
    project schema to find which properties need a vector field, and builds
    the vectorizer. ``invoke`` then fills those vector fields on every node
    of the given sub graph.
    """

    def __init__(self, project_id: str = None, **kwargs):
        """Set up configuration, vector metadata and the vectorizer.

        Args:
            project_id: Id of the project; falls back to the
                ``KAG_PROJECT_ID`` environment variable when not given.
            **kwargs: Forwarded to the base class constructor.
        """
        super().__init__(**kwargs)
        self.project_id = project_id or os.getenv("KAG_PROJECT_ID")
        # Must run before the vectorizer is built: it sets self.vectorizer_config.
        self._init_graph_store()
        self.vec_meta = self._init_vec_meta()
        self.vectorizer = Vectorizer.from_config(self.vectorizer_config)

    def _init_graph_store(self):
        """Resolve the vectorizer configuration and store it on the instance.

        The base configuration is read from the ``KAG_VECTORIZER`` environment
        variable and then overridden by the ``"vectorizer"`` section of the
        project configuration fetched for ``self.project_id``. The result is
        stored in ``self.vectorizer_config``; nothing is returned.

        Despite its name, this method does not create a graph store client.
        A graph store configuration used to be computed here but was never
        used, so that dead computation has been removed.
        """
        # NOTE(review): eval() executes arbitrary code found in the environment
        # variable. Kept to stay compatible with existing deployments; consider
        # ast.literal_eval or json.loads once the accepted format is confirmed.
        vectorizer_config = eval(os.getenv("KAG_VECTORIZER", "{}"))
        config = ProjectClient().get_config(self.project_id)
        vectorizer_config.update(config.get("vectorizer", {}))
        self.vectorizer_config = vectorizer_config

    def _init_vec_meta(self):
        """Build the mapping from type name to its vector field names.

        A property contributes a vector field when it is called ``"name"`` or
        when its index type is ``Vector`` or ``TextAndVector``.

        Returns:
            ``defaultdict(list)`` mapping each type name to the list of vector
            field names derived from its qualifying properties.
        """
        vec_meta = defaultdict(list)
        schema_client = SchemaClient(project_id=self.project_id)
        spg_types = schema_client.load()
        vector_index_types = (IndexTypeEnum.Vector, IndexTypeEnum.TextAndVector)
        for type_name, spg_type in spg_types.items():
            for prop_name, prop in spg_type.properties.items():
                if prop_name == "name" or prop.index_type in vector_index_types:
                    vec_meta[type_name].append(self._create_vector_field_name(prop_name))
        return vec_meta

    def _create_vector_field_name(self, property_key):
        """Return the vector field name derived from ``property_key``.

        The result is ``"_"`` followed by the snake-cased form of
        ``f"{property_key}_vector"``.
        """
        from kag.common.utils import to_snake_case

        return "_" + to_snake_case(f"{property_key}_vector")

    def _generate_embedding_vectors(self, vectorizer: Vectorizer, input: SubGraph) -> SubGraph:
        """Compute embedding vectors for the nodes of ``input`` in one batch.

        Nodes without an id or a name are skipped. For every other node the
        vector fields listed in ``self.vec_meta`` are generated and merged
        into ``node.properties`` in place.

        Args:
            vectorizer: Vectorizer used to embed the texts.
            input: Sub graph whose nodes are updated.

        Returns:
            The same ``input`` object, with node properties updated.
        """
        node_list = []
        node_batch = []
        for node in input.nodes:
            if not node.id or not node.name:
                continue
            properties = {"id": node.id, "name": node.name}
            properties.update(node.properties)
            node_list.append((node, properties))
            # The generator works on a copy so the original values stay
            # available for the comparison below.
            node_batch.append((node.label, properties.copy()))

        generator = EmbeddingVectorGenerator(vectorizer, self.vec_meta)
        generator.batch_generate(node_batch)

        for (node, properties), (_node_label, new_properties) in zip(node_list, node_batch):
            # Only write back entries that are new or whose value changed;
            # this keeps the synthetic "id"/"name" entries out of
            # node.properties unless they differ.
            changed = {
                key: value
                for key, value in new_properties.items()
                if not (key in properties and value == properties[key])
            }
            node.properties.update(changed)
        return input

    def invoke(self, input: Input, **kwargs) -> List[Output]:
        """Add embedding vectors to ``input`` and return it in a one-item list.

        Args:
            input: The sub graph to process.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A list containing the updated input.
        """
        modified_input = self._generate_embedding_vectors(self.vectorizer, input)
        return [modified_input]
|