mirror of
https://github.com/OpenSPG/KAG.git
synced 2025-06-27 03:20:08 +00:00
support disabling vector generation (#486)
This commit is contained in:
parent
a7fd51d138
commit
6745cf1744
@ -11,7 +11,7 @@
|
||||
# or implied.
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from tenacity import stop_after_attempt, retry
|
||||
|
||||
from kag.builder.model.sub_graph import SubGraph
|
||||
@ -43,11 +43,15 @@ class EmbeddingVectorPlaceholder(object):
|
||||
|
||||
|
||||
class EmbeddingVectorManager(object):
|
||||
def __init__(self):
|
||||
def __init__(self, disable_generation=None):
|
||||
self._placeholders = []
|
||||
self._disable_generation = frozenset(disable_generation or [])
|
||||
|
||||
def get_placeholder(self, properties, vector_field):
|
||||
def get_placeholder(self, label, properties, vector_field):
|
||||
for property_key, property_value in properties.items():
|
||||
disable_prop_key = f"{label}.{property_key}"
|
||||
if disable_prop_key in self._disable_generation:
|
||||
continue
|
||||
field_name = get_vector_field_name(property_key)
|
||||
if field_name != vector_field:
|
||||
continue
|
||||
@ -126,13 +130,20 @@ class EmbeddingVectorManager(object):
|
||||
|
||||
|
||||
class EmbeddingVectorGenerator(object):
|
||||
def __init__(self, vectorizer, vector_index_meta=None, extra_labels=("Entity",)):
|
||||
def __init__(
|
||||
self,
|
||||
vectorizer,
|
||||
vector_index_meta=None,
|
||||
disable_generation=None,
|
||||
extra_labels=("Entity",),
|
||||
):
|
||||
self._vectorizer = vectorizer
|
||||
self._extra_labels = extra_labels
|
||||
self._vector_index_meta = vector_index_meta or {}
|
||||
self._disable_generation = disable_generation
|
||||
|
||||
def batch_generate(self, node_batch, batch_size=32):
|
||||
manager = EmbeddingVectorManager()
|
||||
manager = EmbeddingVectorManager(self._disable_generation)
|
||||
vector_index_meta = self._vector_index_meta
|
||||
for node_item in node_batch:
|
||||
label, properties = node_item
|
||||
@ -145,14 +156,16 @@ class EmbeddingVectorGenerator(object):
|
||||
for vector_field in vector_index_meta[label]:
|
||||
if vector_field in properties:
|
||||
continue
|
||||
placeholder = manager.get_placeholder(properties, vector_field)
|
||||
placeholder = manager.get_placeholder(
|
||||
label, properties, vector_field
|
||||
)
|
||||
if placeholder is not None:
|
||||
properties[vector_field] = placeholder
|
||||
manager.batch_generate(self._vectorizer, batch_size)
|
||||
manager.patch()
|
||||
|
||||
async def abatch_generate(self, node_batch, batch_size=32):
|
||||
manager = EmbeddingVectorManager()
|
||||
manager = EmbeddingVectorManager(self._disable_generation)
|
||||
vector_index_meta = self._vector_index_meta
|
||||
for node_item in node_batch:
|
||||
label, properties = node_item
|
||||
@ -165,7 +178,9 @@ class EmbeddingVectorGenerator(object):
|
||||
for vector_field in vector_index_meta[label]:
|
||||
if vector_field in properties:
|
||||
continue
|
||||
placeholder = manager.get_placeholder(properties, vector_field)
|
||||
placeholder = manager.get_placeholder(
|
||||
label, properties, vector_field
|
||||
)
|
||||
if placeholder is not None:
|
||||
properties[vector_field] = placeholder
|
||||
await manager.abatch_generate(self._vectorizer, batch_size)
|
||||
@ -189,7 +204,12 @@ class BatchVectorizer(VectorizerABC):
|
||||
batch_size (int): The size of the batches in which to process the nodes.
|
||||
"""
|
||||
|
||||
def __init__(self, vectorize_model: VectorizeModelABC, batch_size: int = 32):
|
||||
def __init__(
|
||||
self,
|
||||
vectorize_model: VectorizeModelABC,
|
||||
batch_size: int = 32,
|
||||
disable_generation: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Initializes the BatchVectorizer with the specified vectorization model and batch size.
|
||||
|
||||
@ -203,6 +223,7 @@ class BatchVectorizer(VectorizerABC):
|
||||
self.vec_meta = self._init_vec_meta()
|
||||
self.vectorize_model = vectorize_model
|
||||
self.batch_size = batch_size
|
||||
self.disable_generation = disable_generation
|
||||
|
||||
def _init_vec_meta(self):
|
||||
"""
|
||||
@ -246,7 +267,9 @@ class BatchVectorizer(VectorizerABC):
|
||||
properties.update(node.properties)
|
||||
node_list.append((node, properties))
|
||||
node_batch.append((node.label, properties.copy()))
|
||||
generator = EmbeddingVectorGenerator(self.vectorize_model, self.vec_meta)
|
||||
generator = EmbeddingVectorGenerator(
|
||||
self.vectorize_model, self.vec_meta, self.disable_generation
|
||||
)
|
||||
generator.batch_generate(node_batch, self.batch_size)
|
||||
for (node, properties), (_node_label, new_properties) in zip(
|
||||
node_list, node_batch
|
||||
@ -277,7 +300,9 @@ class BatchVectorizer(VectorizerABC):
|
||||
properties.update(node.properties)
|
||||
node_list.append((node, properties))
|
||||
node_batch.append((node.label, properties.copy()))
|
||||
generator = EmbeddingVectorGenerator(self.vectorize_model, self.vec_meta)
|
||||
generator = EmbeddingVectorGenerator(
|
||||
self.vectorize_model, self.vec_meta, self.disable_generation
|
||||
)
|
||||
await generator.abatch_generate(node_batch, self.batch_size)
|
||||
for (node, properties), (_node_label, new_properties) in zip(
|
||||
node_list, node_batch
|
||||
|
@ -10,7 +10,7 @@
|
||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
# or implied.
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from kag.builder.component.vectorizer.batch_vectorizer import EmbeddingVectorGenerator
|
||||
from tenacity import stop_after_attempt, retry
|
||||
|
||||
@ -39,7 +39,12 @@ class AffairBatchVectorizer(VectorizerABC):
|
||||
batch_size (int): The size of the batches in which to process the nodes.
|
||||
"""
|
||||
|
||||
def __init__(self, vectorize_model: VectorizeModelABC, batch_size: int = 32):
|
||||
def __init__(
|
||||
self,
|
||||
vectorize_model: VectorizeModelABC,
|
||||
batch_size: int = 32,
|
||||
disable_generation: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Initializes the BatchVectorizer with the specified vectorization model and batch size.
|
||||
|
||||
@ -53,6 +58,7 @@ class AffairBatchVectorizer(VectorizerABC):
|
||||
self.vec_meta = self._init_vec_meta()
|
||||
self.vectorize_model = vectorize_model
|
||||
self.batch_size = batch_size
|
||||
self.disable_generation = disable_generation
|
||||
|
||||
def _init_vec_meta(self):
|
||||
"""
|
||||
@ -95,7 +101,9 @@ class AffairBatchVectorizer(VectorizerABC):
|
||||
properties.update(node.properties)
|
||||
node_list.append((node, properties))
|
||||
node_batch.append((node.label, properties.copy()))
|
||||
generator = EmbeddingVectorGenerator(self.vectorize_model, self.vec_meta)
|
||||
generator = EmbeddingVectorGenerator(
|
||||
self.vectorize_model, self.vec_meta, self.disable_generation
|
||||
)
|
||||
generator.batch_generate(node_batch, self.batch_size)
|
||||
for (node, properties), (_node_label, new_properties) in zip(
|
||||
node_list, node_batch
|
||||
|
Loading…
x
Reference in New Issue
Block a user