KAG/kag/common/vectorizer/batch_vectorizer.py
2024-10-24 11:46:15 +08:00

113 lines
4.2 KiB
Python

# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
class Neo4jEmbeddingVectorPlaceholder(object):
    """Stand-in for an embedding vector that has not been computed yet.

    The placeholder remembers which ``properties`` mapping and which
    ``vector_field`` key it belongs to, together with the source property
    (key and text value) the vector is to be generated from.  Once
    ``_embedding_vector`` has been assigned, :meth:`replace` writes it into
    the properties mapping under ``vector_field``.
    """

    def __init__(self, number, properties, vector_field, property_key, property_value):
        """Store the placeholder's bookkeeping data.

        Args:
            number: Identifier of this placeholder; it is what ``repr()`` shows.
            properties: Mutable mapping that will receive the vector.
            vector_field: Key in ``properties`` under which the vector is stored.
            property_key: Name of the property the vector is derived from.
            property_value: Value of that property (the text to embed).
        """
        # No vector is known at construction time.
        self._embedding_vector = None
        self._property_value = property_value
        self._property_key = property_key
        self._vector_field = vector_field
        self._properties = properties
        self._number = number

    def replace(self):
        """Write the embedding vector into the properties, if one was set.

        Does nothing while ``_embedding_vector`` is still ``None``.
        """
        vector = self._embedding_vector
        if vector is None:
            return
        self._properties[self._vector_field] = vector

    def __repr__(self):
        """Return the ``repr`` of the placeholder's number."""
        return repr(self._number)
class Neo4jEmbeddingVectorManager(object):
    """Collect embedding placeholders and resolve them in one batch.

    Typical flow: call :meth:`get_placeholder` for every vector field that
    needs a value, then :meth:`batch_vectorize` to compute the vectors, then
    :meth:`patch` to write the vectors into the owning properties mappings.
    Placeholders that share the same text are grouped, so each distinct text
    is passed to the vectorizer only once.
    """

    def __init__(self):
        # Every placeholder handed out by get_placeholder, in creation order.
        self._placeholders = []

    def _create_vector_field_name(self, property_key):
        """Return the vector field name derived from ``property_key``.

        The name is ``"_"`` followed by the result of ``to_snake_case`` applied
        to ``"<property_key>_vector"``.
        """
        # Kept as a function-local import, exactly as in the original code.
        from kag.common.utils import to_snake_case

        name = f"{property_key}_vector"
        name = to_snake_case(name)
        return "_" + name

    def get_placeholder(self, properties, vector_field):
        """Create a placeholder for ``vector_field`` if a source property exists.

        Scans ``properties`` for the first key whose derived vector field name
        equals ``vector_field``.

        Args:
            properties: Mapping of property names to values for one node.
            vector_field: Name of the vector field that needs a value.

        Returns:
            A newly registered ``Neo4jEmbeddingVectorPlaceholder``, or ``None``
            when no property maps to ``vector_field`` or the matching
            property's value is falsy (e.g. ``None`` or an empty string).

        Raises:
            RuntimeError: If the matching property value is truthy but not a
                ``str``.
        """
        for property_key, property_value in properties.items():
            field_name = self._create_vector_field_name(property_key)
            if field_name != vector_field:
                continue
            if not property_value:
                # Nothing to embed for this field.
                return None
            if not isinstance(property_value, str):
                message = f"property {property_key!r} must be string to generate embedding vector"
                raise RuntimeError(message)
            # The placeholder number is its index in the registration list.
            num = len(self._placeholders)
            placeholder = Neo4jEmbeddingVectorPlaceholder(
                num, properties, vector_field, property_key, property_value
            )
            self._placeholders.append(placeholder)
            return placeholder
        return None

    def _get_text_batch(self):
        """Group the registered placeholders by the text they embed.

        Returns:
            A dict mapping each distinct text to the list of placeholders that
            use it, in first-seen order.
        """
        text_batch = {}
        for placeholder in self._placeholders:
            text_batch.setdefault(placeholder._property_value, []).append(placeholder)
        return text_batch

    def _generate_vectors(self, vectorizer, text_batch):
        """Vectorize the distinct texts of ``text_batch``.

        Args:
            vectorizer: Object exposing ``vectorize(texts)``.
            text_batch: Mapping produced by :meth:`_get_text_batch`.

        Returns:
            Whatever ``vectorizer.vectorize`` returns for the list of texts, or
            an empty list when there are no texts.
        """
        texts = list(text_batch)
        if not texts:
            # Nothing was registered: skip the vectorizer call entirely rather
            # than sending it an empty batch.
            return []
        vectors = vectorizer.vectorize(texts)
        return vectors

    def _fill_vectors(self, vectors, text_batch):
        """Assign each vector to all placeholders that share its text.

        Vectors are paired with texts positionally; this relies on
        ``text_batch`` iterating in the same order as the text list built from
        it in :meth:`_generate_vectors`.
        """
        # NOTE(review): assumes the vectorizer returns exactly one vector per
        # text, in input order -- zip() silently stops at the shorter sequence.
        for vector, (_text, placeholders) in zip(vectors, text_batch.items()):
            for placeholder in placeholders:
                placeholder._embedding_vector = vector

    def batch_vectorize(self, vectorizer):
        """Compute vectors for all registered placeholders in one call.

        Args:
            vectorizer: Object exposing ``vectorize(texts)``.
        """
        text_batch = self._get_text_batch()
        vectors = self._generate_vectors(vectorizer, text_batch)
        self._fill_vectors(vectors, text_batch)

    def patch(self):
        """Ask every registered placeholder to write its vector back."""
        for placeholder in self._placeholders:
            placeholder.replace()
class Neo4jBatchVectorizer(object):
    """Fill in missing vector fields for a batch of nodes.

    ``vector_index_meta`` maps a label to the vector field names indexed for
    that label.  For every node, each such field that is absent from the
    node's properties is populated with an embedding computed by the given
    vectorizer.
    """

    def __init__(self, vectorizer, vector_index_meta=None, extra_labels=("Entity",)):
        """Remember the vectorizer and the index configuration.

        Args:
            vectorizer: Object used to compute the embedding vectors.
            vector_index_meta: Mapping of label to vector field names; a falsy
                value is replaced by an empty dict.
            extra_labels: Labels consulted for every node in addition to the
                node's own label.
        """
        self._vector_index_meta = vector_index_meta or {}
        self._extra_labels = extra_labels
        self._vectorizer = vectorizer

    def _labels_for(self, label):
        """Return the node's own label followed by the configured extra labels."""
        extras = self._extra_labels
        candidates = [label]
        if extras:
            candidates.extend(extras)
        return candidates

    def batch_vectorize(self, node_batch):
        """Compute and attach embedding vectors for ``node_batch`` in place.

        Args:
            node_batch: Iterable of ``(label, properties)`` pairs; each
                ``properties`` mapping is mutated.
        """
        index_meta = self._vector_index_meta
        manager = Neo4jEmbeddingVectorManager()
        for node_label, properties in node_batch:
            for candidate in self._labels_for(node_label):
                if candidate in index_meta:
                    for vector_field in index_meta[candidate]:
                        # Fields that already have a value are left untouched.
                        if vector_field not in properties:
                            pending = manager.get_placeholder(properties, vector_field)
                            if pending is not None:
                                # Temporarily store the placeholder; patch()
                                # below swaps in the computed vector.
                                properties[vector_field] = pending
        manager.batch_vectorize(self._vectorizer)
        manager.patch()