From 6745cf1744901db29ceff3aba82f574b2b70d990 Mon Sep 17 00:00:00 2001 From: xionghuaidong Date: Thu, 24 Apr 2025 13:36:08 +0800 Subject: [PATCH] support disabling vector generation (#486) --- .../component/vectorizer/batch_vectorizer.py | 47 ++++++++++++++----- .../builder/affair_batch_vectorizer.py | 14 ++++-- 2 files changed, 47 insertions(+), 14 deletions(-) diff --git a/kag/builder/component/vectorizer/batch_vectorizer.py b/kag/builder/component/vectorizer/batch_vectorizer.py index 4bf982eb..6a18fb06 100644 --- a/kag/builder/component/vectorizer/batch_vectorizer.py +++ b/kag/builder/component/vectorizer/batch_vectorizer.py @@ -11,7 +11,7 @@ # or implied. import asyncio from collections import defaultdict -from typing import List +from typing import List, Optional from tenacity import stop_after_attempt, retry from kag.builder.model.sub_graph import SubGraph @@ -43,11 +43,15 @@ class EmbeddingVectorPlaceholder(object): class EmbeddingVectorManager(object): - def __init__(self): + def __init__(self, disable_generation=None): self._placeholders = [] + self._disable_generation = frozenset(disable_generation or []) - def get_placeholder(self, properties, vector_field): + def get_placeholder(self, label, properties, vector_field): for property_key, property_value in properties.items(): + disable_prop_key = f"{label}.{property_key}" + if disable_prop_key in self._disable_generation: + continue field_name = get_vector_field_name(property_key) if field_name != vector_field: continue @@ -126,13 +130,20 @@ class EmbeddingVectorManager(object): class EmbeddingVectorGenerator(object): - def __init__(self, vectorizer, vector_index_meta=None, extra_labels=("Entity",)): + def __init__( + self, + vectorizer, + vector_index_meta=None, + disable_generation=None, + extra_labels=("Entity",), + ): self._vectorizer = vectorizer self._extra_labels = extra_labels self._vector_index_meta = vector_index_meta or {} + self._disable_generation = disable_generation def batch_generate(self, node_batch, batch_size=32): - manager = EmbeddingVectorManager() + manager = EmbeddingVectorManager(self._disable_generation) vector_index_meta = self._vector_index_meta for node_item in node_batch: label, properties = node_item @@ -145,14 +156,16 @@ class EmbeddingVectorGenerator(object): for vector_field in vector_index_meta[label]: if vector_field in properties: continue - placeholder = manager.get_placeholder(properties, vector_field) + placeholder = manager.get_placeholder( + label, properties, vector_field + ) if placeholder is not None: properties[vector_field] = placeholder manager.batch_generate(self._vectorizer, batch_size) manager.patch() async def abatch_generate(self, node_batch, batch_size=32): - manager = EmbeddingVectorManager() + manager = EmbeddingVectorManager(self._disable_generation) vector_index_meta = self._vector_index_meta for node_item in node_batch: label, properties = node_item @@ -165,7 +178,9 @@ class EmbeddingVectorGenerator(object): for vector_field in vector_index_meta[label]: if vector_field in properties: continue - placeholder = manager.get_placeholder(properties, vector_field) + placeholder = manager.get_placeholder( + label, properties, vector_field + ) if placeholder is not None: properties[vector_field] = placeholder await manager.abatch_generate(self._vectorizer, batch_size) @@ -189,7 +204,12 @@ class BatchVectorizer(VectorizerABC): batch_size (int): The size of the batches in which to process the nodes. """ - def __init__(self, vectorize_model: VectorizeModelABC, batch_size: int = 32): + def __init__( + self, + vectorize_model: VectorizeModelABC, + batch_size: int = 32, + disable_generation: Optional[List[str]] = None, + ): """ Initializes the BatchVectorizer with the specified vectorization model and batch size. @@ -203,6 +223,7 @@ class BatchVectorizer(VectorizerABC): self.vec_meta = self._init_vec_meta() self.vectorize_model = vectorize_model self.batch_size = batch_size + self.disable_generation = disable_generation def _init_vec_meta(self): """ @@ -246,7 +267,9 @@ class BatchVectorizer(VectorizerABC): properties.update(node.properties) node_list.append((node, properties)) node_batch.append((node.label, properties.copy())) - generator = EmbeddingVectorGenerator(self.vectorize_model, self.vec_meta) + generator = EmbeddingVectorGenerator( + self.vectorize_model, self.vec_meta, self.disable_generation + ) generator.batch_generate(node_batch, self.batch_size) for (node, properties), (_node_label, new_properties) in zip( node_list, node_batch @@ -277,7 +300,9 @@ class BatchVectorizer(VectorizerABC): properties.update(node.properties) node_list.append((node, properties)) node_batch.append((node.label, properties.copy())) - generator = EmbeddingVectorGenerator(self.vectorize_model, self.vec_meta) + generator = EmbeddingVectorGenerator( + self.vectorize_model, self.vec_meta, self.disable_generation + ) await generator.abatch_generate(node_batch, self.batch_size) for (node, properties), (_node_label, new_properties) in zip( node_list, node_batch diff --git a/kag/open_benchmark/AffairQA/builder/affair_batch_vectorizer.py b/kag/open_benchmark/AffairQA/builder/affair_batch_vectorizer.py index 1810dc2d..19a0dd37 100644 --- a/kag/open_benchmark/AffairQA/builder/affair_batch_vectorizer.py +++ b/kag/open_benchmark/AffairQA/builder/affair_batch_vectorizer.py @@ -10,7 +10,7 @@ # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. from collections import defaultdict -from typing import List +from typing import List, Optional from kag.builder.component.vectorizer.batch_vectorizer import EmbeddingVectorGenerator from tenacity import stop_after_attempt, retry @@ -39,7 +39,12 @@ class AffairBatchVectorizer(VectorizerABC): batch_size (int): The size of the batches in which to process the nodes. """ - def __init__(self, vectorize_model: VectorizeModelABC, batch_size: int = 32): + def __init__( + self, + vectorize_model: VectorizeModelABC, + batch_size: int = 32, + disable_generation: Optional[List[str]] = None, + ): """ Initializes the BatchVectorizer with the specified vectorization model and batch size. @@ -53,6 +58,7 @@ class AffairBatchVectorizer(VectorizerABC): self.vec_meta = self._init_vec_meta() self.vectorize_model = vectorize_model self.batch_size = batch_size + self.disable_generation = disable_generation def _init_vec_meta(self): """ @@ -95,7 +101,9 @@ class AffairBatchVectorizer(VectorizerABC): properties.update(node.properties) node_list.append((node, properties)) node_batch.append((node.label, properties.copy())) - generator = EmbeddingVectorGenerator(self.vectorize_model, self.vec_meta) + generator = EmbeddingVectorGenerator( + self.vectorize_model, self.vec_meta, self.disable_generation + ) generator.batch_generate(node_batch, self.batch_size) for (node, properties), (_node_label, new_properties) in zip( node_list, node_batch