mirror of
https://github.com/OpenSPG/KAG.git
synced 2025-06-27 03:20:08 +00:00
support disabling vector generation (#486)
This commit is contained in:
parent
a7fd51d138
commit
6745cf1744
@ -11,7 +11,7 @@
|
|||||||
# or implied.
|
# or implied.
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from tenacity import stop_after_attempt, retry
|
from tenacity import stop_after_attempt, retry
|
||||||
|
|
||||||
from kag.builder.model.sub_graph import SubGraph
|
from kag.builder.model.sub_graph import SubGraph
|
||||||
@ -43,11 +43,15 @@ class EmbeddingVectorPlaceholder(object):
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingVectorManager(object):
|
class EmbeddingVectorManager(object):
|
||||||
def __init__(self):
|
def __init__(self, disable_generation=None):
|
||||||
self._placeholders = []
|
self._placeholders = []
|
||||||
|
self._disable_generation = frozenset(disable_generation or [])
|
||||||
|
|
||||||
def get_placeholder(self, properties, vector_field):
|
def get_placeholder(self, label, properties, vector_field):
|
||||||
for property_key, property_value in properties.items():
|
for property_key, property_value in properties.items():
|
||||||
|
disable_prop_key = f"{label}.{property_key}"
|
||||||
|
if disable_prop_key in self._disable_generation:
|
||||||
|
continue
|
||||||
field_name = get_vector_field_name(property_key)
|
field_name = get_vector_field_name(property_key)
|
||||||
if field_name != vector_field:
|
if field_name != vector_field:
|
||||||
continue
|
continue
|
||||||
@ -126,13 +130,20 @@ class EmbeddingVectorManager(object):
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingVectorGenerator(object):
|
class EmbeddingVectorGenerator(object):
|
||||||
def __init__(self, vectorizer, vector_index_meta=None, extra_labels=("Entity",)):
|
def __init__(
|
||||||
|
self,
|
||||||
|
vectorizer,
|
||||||
|
vector_index_meta=None,
|
||||||
|
disable_generation=None,
|
||||||
|
extra_labels=("Entity",),
|
||||||
|
):
|
||||||
self._vectorizer = vectorizer
|
self._vectorizer = vectorizer
|
||||||
self._extra_labels = extra_labels
|
self._extra_labels = extra_labels
|
||||||
self._vector_index_meta = vector_index_meta or {}
|
self._vector_index_meta = vector_index_meta or {}
|
||||||
|
self._disable_generation = disable_generation
|
||||||
|
|
||||||
def batch_generate(self, node_batch, batch_size=32):
|
def batch_generate(self, node_batch, batch_size=32):
|
||||||
manager = EmbeddingVectorManager()
|
manager = EmbeddingVectorManager(self._disable_generation)
|
||||||
vector_index_meta = self._vector_index_meta
|
vector_index_meta = self._vector_index_meta
|
||||||
for node_item in node_batch:
|
for node_item in node_batch:
|
||||||
label, properties = node_item
|
label, properties = node_item
|
||||||
@ -145,14 +156,16 @@ class EmbeddingVectorGenerator(object):
|
|||||||
for vector_field in vector_index_meta[label]:
|
for vector_field in vector_index_meta[label]:
|
||||||
if vector_field in properties:
|
if vector_field in properties:
|
||||||
continue
|
continue
|
||||||
placeholder = manager.get_placeholder(properties, vector_field)
|
placeholder = manager.get_placeholder(
|
||||||
|
label, properties, vector_field
|
||||||
|
)
|
||||||
if placeholder is not None:
|
if placeholder is not None:
|
||||||
properties[vector_field] = placeholder
|
properties[vector_field] = placeholder
|
||||||
manager.batch_generate(self._vectorizer, batch_size)
|
manager.batch_generate(self._vectorizer, batch_size)
|
||||||
manager.patch()
|
manager.patch()
|
||||||
|
|
||||||
async def abatch_generate(self, node_batch, batch_size=32):
|
async def abatch_generate(self, node_batch, batch_size=32):
|
||||||
manager = EmbeddingVectorManager()
|
manager = EmbeddingVectorManager(self._disable_generation)
|
||||||
vector_index_meta = self._vector_index_meta
|
vector_index_meta = self._vector_index_meta
|
||||||
for node_item in node_batch:
|
for node_item in node_batch:
|
||||||
label, properties = node_item
|
label, properties = node_item
|
||||||
@ -165,7 +178,9 @@ class EmbeddingVectorGenerator(object):
|
|||||||
for vector_field in vector_index_meta[label]:
|
for vector_field in vector_index_meta[label]:
|
||||||
if vector_field in properties:
|
if vector_field in properties:
|
||||||
continue
|
continue
|
||||||
placeholder = manager.get_placeholder(properties, vector_field)
|
placeholder = manager.get_placeholder(
|
||||||
|
label, properties, vector_field
|
||||||
|
)
|
||||||
if placeholder is not None:
|
if placeholder is not None:
|
||||||
properties[vector_field] = placeholder
|
properties[vector_field] = placeholder
|
||||||
await manager.abatch_generate(self._vectorizer, batch_size)
|
await manager.abatch_generate(self._vectorizer, batch_size)
|
||||||
@ -189,7 +204,12 @@ class BatchVectorizer(VectorizerABC):
|
|||||||
batch_size (int): The size of the batches in which to process the nodes.
|
batch_size (int): The size of the batches in which to process the nodes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, vectorize_model: VectorizeModelABC, batch_size: int = 32):
|
def __init__(
|
||||||
|
self,
|
||||||
|
vectorize_model: VectorizeModelABC,
|
||||||
|
batch_size: int = 32,
|
||||||
|
disable_generation: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the BatchVectorizer with the specified vectorization model and batch size.
|
Initializes the BatchVectorizer with the specified vectorization model and batch size.
|
||||||
|
|
||||||
@ -203,6 +223,7 @@ class BatchVectorizer(VectorizerABC):
|
|||||||
self.vec_meta = self._init_vec_meta()
|
self.vec_meta = self._init_vec_meta()
|
||||||
self.vectorize_model = vectorize_model
|
self.vectorize_model = vectorize_model
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
self.disable_generation = disable_generation
|
||||||
|
|
||||||
def _init_vec_meta(self):
|
def _init_vec_meta(self):
|
||||||
"""
|
"""
|
||||||
@ -246,7 +267,9 @@ class BatchVectorizer(VectorizerABC):
|
|||||||
properties.update(node.properties)
|
properties.update(node.properties)
|
||||||
node_list.append((node, properties))
|
node_list.append((node, properties))
|
||||||
node_batch.append((node.label, properties.copy()))
|
node_batch.append((node.label, properties.copy()))
|
||||||
generator = EmbeddingVectorGenerator(self.vectorize_model, self.vec_meta)
|
generator = EmbeddingVectorGenerator(
|
||||||
|
self.vectorize_model, self.vec_meta, self.disable_generation
|
||||||
|
)
|
||||||
generator.batch_generate(node_batch, self.batch_size)
|
generator.batch_generate(node_batch, self.batch_size)
|
||||||
for (node, properties), (_node_label, new_properties) in zip(
|
for (node, properties), (_node_label, new_properties) in zip(
|
||||||
node_list, node_batch
|
node_list, node_batch
|
||||||
@ -277,7 +300,9 @@ class BatchVectorizer(VectorizerABC):
|
|||||||
properties.update(node.properties)
|
properties.update(node.properties)
|
||||||
node_list.append((node, properties))
|
node_list.append((node, properties))
|
||||||
node_batch.append((node.label, properties.copy()))
|
node_batch.append((node.label, properties.copy()))
|
||||||
generator = EmbeddingVectorGenerator(self.vectorize_model, self.vec_meta)
|
generator = EmbeddingVectorGenerator(
|
||||||
|
self.vectorize_model, self.vec_meta, self.disable_generation
|
||||||
|
)
|
||||||
await generator.abatch_generate(node_batch, self.batch_size)
|
await generator.abatch_generate(node_batch, self.batch_size)
|
||||||
for (node, properties), (_node_label, new_properties) in zip(
|
for (node, properties), (_node_label, new_properties) in zip(
|
||||||
node_list, node_batch
|
node_list, node_batch
|
||||||
|
@ -10,7 +10,7 @@
|
|||||||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
# or implied.
|
# or implied.
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from kag.builder.component.vectorizer.batch_vectorizer import EmbeddingVectorGenerator
|
from kag.builder.component.vectorizer.batch_vectorizer import EmbeddingVectorGenerator
|
||||||
from tenacity import stop_after_attempt, retry
|
from tenacity import stop_after_attempt, retry
|
||||||
|
|
||||||
@ -39,7 +39,12 @@ class AffairBatchVectorizer(VectorizerABC):
|
|||||||
batch_size (int): The size of the batches in which to process the nodes.
|
batch_size (int): The size of the batches in which to process the nodes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, vectorize_model: VectorizeModelABC, batch_size: int = 32):
|
def __init__(
|
||||||
|
self,
|
||||||
|
vectorize_model: VectorizeModelABC,
|
||||||
|
batch_size: int = 32,
|
||||||
|
disable_generation: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the BatchVectorizer with the specified vectorization model and batch size.
|
Initializes the BatchVectorizer with the specified vectorization model and batch size.
|
||||||
|
|
||||||
@ -53,6 +58,7 @@ class AffairBatchVectorizer(VectorizerABC):
|
|||||||
self.vec_meta = self._init_vec_meta()
|
self.vec_meta = self._init_vec_meta()
|
||||||
self.vectorize_model = vectorize_model
|
self.vectorize_model = vectorize_model
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
self.disable_generation = disable_generation
|
||||||
|
|
||||||
def _init_vec_meta(self):
|
def _init_vec_meta(self):
|
||||||
"""
|
"""
|
||||||
@ -95,7 +101,9 @@ class AffairBatchVectorizer(VectorizerABC):
|
|||||||
properties.update(node.properties)
|
properties.update(node.properties)
|
||||||
node_list.append((node, properties))
|
node_list.append((node, properties))
|
||||||
node_batch.append((node.label, properties.copy()))
|
node_batch.append((node.label, properties.copy()))
|
||||||
generator = EmbeddingVectorGenerator(self.vectorize_model, self.vec_meta)
|
generator = EmbeddingVectorGenerator(
|
||||||
|
self.vectorize_model, self.vec_meta, self.disable_generation
|
||||||
|
)
|
||||||
generator.batch_generate(node_batch, self.batch_size)
|
generator.batch_generate(node_batch, self.batch_size)
|
||||||
for (node, properties), (_node_label, new_properties) in zip(
|
for (node, properties), (_node_label, new_properties) in zip(
|
||||||
node_list, node_batch
|
node_list, node_batch
|
||||||
|
Loading…
x
Reference in New Issue
Block a user