support disabling vector generation (#486)

This commit is contained in:
xionghuaidong 2025-04-24 13:36:08 +08:00 committed by GitHub
parent a7fd51d138
commit 6745cf1744
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 14 deletions

View File

@ -11,7 +11,7 @@
# or implied.
import asyncio
from collections import defaultdict
from typing import List
from typing import List, Optional
from tenacity import stop_after_attempt, retry
from kag.builder.model.sub_graph import SubGraph
@ -43,11 +43,15 @@ class EmbeddingVectorPlaceholder(object):
class EmbeddingVectorManager(object):
def __init__(self):
def __init__(self, disable_generation=None):
self._placeholders = []
self._disable_generation = frozenset(disable_generation or [])
def get_placeholder(self, properties, vector_field):
def get_placeholder(self, label, properties, vector_field):
for property_key, property_value in properties.items():
disable_prop_key = f"{label}.{property_key}"
if disable_prop_key in self._disable_generation:
continue
field_name = get_vector_field_name(property_key)
if field_name != vector_field:
continue
@ -126,13 +130,20 @@ class EmbeddingVectorManager(object):
class EmbeddingVectorGenerator(object):
def __init__(self, vectorizer, vector_index_meta=None, extra_labels=("Entity",)):
def __init__(
self,
vectorizer,
vector_index_meta=None,
disable_generation=None,
extra_labels=("Entity",),
):
self._vectorizer = vectorizer
self._extra_labels = extra_labels
self._vector_index_meta = vector_index_meta or {}
self._disable_generation = disable_generation
def batch_generate(self, node_batch, batch_size=32):
manager = EmbeddingVectorManager()
manager = EmbeddingVectorManager(self._disable_generation)
vector_index_meta = self._vector_index_meta
for node_item in node_batch:
label, properties = node_item
@ -145,14 +156,16 @@ class EmbeddingVectorGenerator(object):
for vector_field in vector_index_meta[label]:
if vector_field in properties:
continue
placeholder = manager.get_placeholder(properties, vector_field)
placeholder = manager.get_placeholder(
label, properties, vector_field
)
if placeholder is not None:
properties[vector_field] = placeholder
manager.batch_generate(self._vectorizer, batch_size)
manager.patch()
async def abatch_generate(self, node_batch, batch_size=32):
manager = EmbeddingVectorManager()
manager = EmbeddingVectorManager(self._disable_generation)
vector_index_meta = self._vector_index_meta
for node_item in node_batch:
label, properties = node_item
@ -165,7 +178,9 @@ class EmbeddingVectorGenerator(object):
for vector_field in vector_index_meta[label]:
if vector_field in properties:
continue
placeholder = manager.get_placeholder(properties, vector_field)
placeholder = manager.get_placeholder(
label, properties, vector_field
)
if placeholder is not None:
properties[vector_field] = placeholder
await manager.abatch_generate(self._vectorizer, batch_size)
@ -189,7 +204,12 @@ class BatchVectorizer(VectorizerABC):
batch_size (int): The size of the batches in which to process the nodes.
"""
def __init__(self, vectorize_model: VectorizeModelABC, batch_size: int = 32):
def __init__(
self,
vectorize_model: VectorizeModelABC,
batch_size: int = 32,
disable_generation: Optional[List[str]] = None,
):
"""
Initializes the BatchVectorizer with the specified vectorization model and batch size.
@ -203,6 +223,7 @@ class BatchVectorizer(VectorizerABC):
self.vec_meta = self._init_vec_meta()
self.vectorize_model = vectorize_model
self.batch_size = batch_size
self.disable_generation = disable_generation
def _init_vec_meta(self):
"""
@ -246,7 +267,9 @@ class BatchVectorizer(VectorizerABC):
properties.update(node.properties)
node_list.append((node, properties))
node_batch.append((node.label, properties.copy()))
generator = EmbeddingVectorGenerator(self.vectorize_model, self.vec_meta)
generator = EmbeddingVectorGenerator(
self.vectorize_model, self.vec_meta, self.disable_generation
)
generator.batch_generate(node_batch, self.batch_size)
for (node, properties), (_node_label, new_properties) in zip(
node_list, node_batch
@ -277,7 +300,9 @@ class BatchVectorizer(VectorizerABC):
properties.update(node.properties)
node_list.append((node, properties))
node_batch.append((node.label, properties.copy()))
generator = EmbeddingVectorGenerator(self.vectorize_model, self.vec_meta)
generator = EmbeddingVectorGenerator(
self.vectorize_model, self.vec_meta, self.disable_generation
)
await generator.abatch_generate(node_batch, self.batch_size)
for (node, properties), (_node_label, new_properties) in zip(
node_list, node_batch

View File

@ -10,7 +10,7 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from collections import defaultdict
from typing import List
from typing import List, Optional
from kag.builder.component.vectorizer.batch_vectorizer import EmbeddingVectorGenerator
from tenacity import stop_after_attempt, retry
@ -39,7 +39,12 @@ class AffairBatchVectorizer(VectorizerABC):
batch_size (int): The size of the batches in which to process the nodes.
"""
def __init__(self, vectorize_model: VectorizeModelABC, batch_size: int = 32):
def __init__(
self,
vectorize_model: VectorizeModelABC,
batch_size: int = 32,
disable_generation: Optional[List[str]] = None,
):
"""
Initializes the AffairBatchVectorizer with the specified vectorization model and batch size.
@ -53,6 +58,7 @@ class AffairBatchVectorizer(VectorizerABC):
self.vec_meta = self._init_vec_meta()
self.vectorize_model = vectorize_model
self.batch_size = batch_size
self.disable_generation = disable_generation
def _init_vec_meta(self):
"""
@ -95,7 +101,9 @@ class AffairBatchVectorizer(VectorizerABC):
properties.update(node.properties)
node_list.append((node, properties))
node_batch.append((node.label, properties.copy()))
generator = EmbeddingVectorGenerator(self.vectorize_model, self.vec_meta)
generator = EmbeddingVectorGenerator(
self.vectorize_model, self.vec_meta, self.disable_generation
)
generator.batch_generate(node_batch, self.batch_size)
for (node, properties), (_node_label, new_properties) in zip(
node_list, node_batch