mirror of
https://github.com/OpenSPG/KAG.git
synced 2025-06-27 03:20:08 +00:00

* add async return * turn vectorize model to singleton * turn vectorize model to singleton * add schema type linker * fix black format * bugfix vectorize --------- Co-authored-by: 钟书 <zhongshu.zzs@antgroup.com>
150 lines
4.8 KiB
Python
150 lines
4.8 KiB
Python
# Copyright 2023 OpenSPG Authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
|
# in compliance with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software distributed under the License
|
|
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
|
# or implied.
|
|
|
|
import io
|
|
import os
|
|
import tarfile
|
|
|
|
import requests
|
|
import logging
|
|
import asyncio
|
|
import threading
|
|
|
|
from tenacity import retry, stop_after_attempt
|
|
from kag.common.rate_limiter import RATE_LIMITER_MANGER
|
|
from kag.common.registry import Registrable
|
|
from typing import Union, Iterable
|
|
|
|
EmbeddingVector = Iterable[float]
|
|
logger = logging.getLogger()
|
|
|
|
|
|
@Registrable.register("vectorize_model")
|
|
class VectorizeModelABC(Registrable):
|
|
"""
|
|
An abstract base class that defines the interface for converting text into embedding vectors.
|
|
"""
|
|
|
|
_instances = {}
|
|
_lock = threading.Lock()
|
|
|
|
def __new__(cls, *args, **kwargs):
|
|
key = cls.generate_key(*args, **kwargs)
|
|
|
|
if key in cls._instances:
|
|
return cls._instances[key]
|
|
|
|
instance = super().__new__(cls)
|
|
cls._instances[key] = instance
|
|
return instance
|
|
|
|
def __init__(
|
|
self,
|
|
name: str,
|
|
vector_dimensions: int = None,
|
|
max_rate: float = 1000,
|
|
time_period: float = 1,
|
|
):
|
|
"""
|
|
Initializes the VectorizeModelABC instance.
|
|
|
|
Args:
|
|
vector_dimensions (int, optional): The number of dimensions for the embedding vectors. Defaults to None.
|
|
"""
|
|
self._vector_dimensions = vector_dimensions
|
|
self.limiter = RATE_LIMITER_MANGER.get_rate_limiter(name, max_rate, time_period)
|
|
|
|
def _download_model(self, path, url):
|
|
"""
|
|
Downloads a model from a specified URL and extracts it to a given path.
|
|
|
|
Args:
|
|
path (str): The directory path to save the downloaded model.
|
|
url (str): The URL from which to download the model.
|
|
|
|
Raises:
|
|
RuntimeError: If the model configuration file is not found at the specified path.
|
|
"""
|
|
logger.info(f"download model from:\n{url} to:\n{path}")
|
|
res = requests.get(url)
|
|
with io.BytesIO(res.content) as fileobj:
|
|
with tarfile.open(fileobj=fileobj) as tar:
|
|
tar.extractall(path=path)
|
|
config_path = os.path.join(path, "config.json")
|
|
if not os.path.isfile(config_path):
|
|
message = f"model config not found at {config_path!r}, url {url!r} specified an invalid model"
|
|
raise RuntimeError(message)
|
|
|
|
def get_vector_dimensions(self):
|
|
"""
|
|
Retrieves the dimension of the generated embedding vectors.
|
|
|
|
Returns:
|
|
int: The number of dimensions for the embedding vectors.
|
|
|
|
Raises:
|
|
RuntimeError: If the embedding service is not available.
|
|
"""
|
|
if hasattr(self, "_vector_dimensions"):
|
|
return int(self._vector_dimensions)
|
|
try:
|
|
example_input = "This is a test."
|
|
example_vector = self.vectorize(example_input)
|
|
self._vector_dimensions = len(example_vector)
|
|
return self._vector_dimensions
|
|
|
|
except Exception as ex:
|
|
message = "the embedding service is not available"
|
|
raise RuntimeError(message) from ex
|
|
|
|
@classmethod
|
|
def generate_key(cls, *args, **kwargs) -> str:
|
|
return f"{cls}"
|
|
|
|
@retry(stop=stop_after_attempt(3), reraise=True)
|
|
def vectorize(
|
|
self, texts: Union[str, Iterable[str]]
|
|
) -> Union[EmbeddingVector, Iterable[EmbeddingVector]]:
|
|
"""
|
|
Vectorizes text(s) into embedding vector(s).
|
|
|
|
Args:
|
|
texts (Union[str, Iterable[str]]): The text or texts to vectorize.
|
|
|
|
Returns:
|
|
Union[EmbeddingVector, Iterable[EmbeddingVector]]: The embedding vector(s) of the text(s).
|
|
|
|
Raises:
|
|
NotImplementedError: This method must be implemented by subclasses.
|
|
"""
|
|
message = "abstract method vectorize is not implemented"
|
|
raise NotImplementedError(message)
|
|
|
|
@retry(stop=stop_after_attempt(3), reraise=True)
|
|
async def avectorize(
|
|
self, texts: Union[str, Iterable[str]]
|
|
) -> Union[EmbeddingVector, Iterable[EmbeddingVector]]:
|
|
"""
|
|
Vectorizes text(s) into embedding vector(s).
|
|
|
|
Args:
|
|
texts (Union[str, Iterable[str]]): The text or texts to vectorize.
|
|
|
|
Returns:
|
|
Union[EmbeddingVector, Iterable[EmbeddingVector]]: The embedding vector(s) of the text(s).
|
|
|
|
Raises:
|
|
NotImplementedError: This method must be implemented by subclasses.
|
|
"""
|
|
|
|
async with self.limiter:
|
|
return await asyncio.to_thread(lambda: self.vectorize(texts))
|