2024-10-27 20:28:57 +08:00

75 lines
2.7 KiB
Python

import os
from typing import Optional, List
from dataclasses import dataclass, field
from sentence_transformers import models, SentenceTransformer
from transformers import HfArgumentParser
def _resolve_pooling_plan(pooling_method):
    """Translate requested pooling tokens into ``(pooling_mode, use_dense)``.

    Works on a private copy so the caller's list is never modified.

    Raises:
        NotImplementedError: neither ``"cls"`` nor ``"mean"`` was requested.
        AssertionError: tokens remain after the pooling mode and the optional
            ``"dense"`` token have been consumed.
    """
    remaining = list(pooling_method)

    # "cls" takes precedence over "mean" when both are present; the loser is
    # then reported as an unused token below.
    if "cls" in remaining:
        pooling_mode = "cls"
    elif "mean" in remaining:
        pooling_mode = "mean"
    else:
        raise NotImplementedError(f"Fail to find cls or mean in pooling_method {pooling_method}!")
    remaining.remove(pooling_mode)

    use_dense = "dense" in remaining
    if use_dense:
        remaining.remove("dense")

    assert len(remaining) == 0, f"Found unused pooling_method {remaining}!"
    return pooling_mode, use_dense


def convert_ours_ckpt_to_sentence_transformer(src_dir, dest_dir, pooling_method: Optional[List[str]] = None, dense_metric: str = "cos"):
    """Wrap an encoder checkpoint as a ``SentenceTransformer`` and save it.

    The saved pipeline is: transformer encoder -> pooling (``cls`` or
    ``mean``) -> optional bias-free dense layer -> optional normalization.

    Args:
        src_dir: Path of the encoder checkpoint; must exist on disk.
        dest_dir: Where to save the converted model. ``None`` means save
            back into ``src_dir``.
        pooling_method: Pooling tokens, a subset of ``{"cls", "mean",
            "dense"}`` containing ``"cls"`` or ``"mean"``. ``None`` (the
            default) means ``["cls"]``. A single string is treated as a
            one-element list. The argument is never mutated.
        dense_metric: When equal to ``"cos"`` a normalization layer is
            appended; any other value adds nothing.

    Raises:
        AssertionError: ``src_dir`` does not exist, ``"decoder"`` pooling was
            requested, or unknown pooling tokens were supplied.
        NotImplementedError: neither ``"cls"`` nor ``"mean"`` was requested.
    """
    if pooling_method is None:
        pooling_method = ["cls"]
    elif isinstance(pooling_method, str):
        pooling_method = [pooling_method]

    assert os.path.exists(src_dir), f"Make sure the encoder path {src_dir} is valid on disk!"
    assert "decoder" not in pooling_method, "Pooling method 'decoder' cannot be saved as sentence_transformers because it uses the decoder stack to produce sentence embedding."

    # Validate the whole configuration before loading the model so that a bad
    # request fails immediately instead of after loading the checkpoint.
    pooling_mode, use_dense = _resolve_pooling_plan(pooling_method)

    if dest_dir is None:
        dest_dir = src_dir

    print(f"loading model from {src_dir} and saving the sentence_transformer model at {dest_dir}...")

    word_embedding_model = models.Transformer(src_dir)
    ndim = word_embedding_model.get_word_embedding_dimension()

    modules = [word_embedding_model, models.Pooling(ndim, pooling_mode=pooling_mode)]

    if use_dense:
        # NOTE(review): this Dense layer is constructed fresh here and no
        # trained weights are loaded into it by this function - confirm that
        # is intended.
        modules.append(models.Dense(ndim, ndim, bias=False))

    if dense_metric == "cos":
        modules.append(models.Normalize())

    model = SentenceTransformer(modules=modules, device='cpu')
    model.save(dest_dir)
@dataclass
class Args:
    """Command-line options for the checkpoint conversion script.

    Creating an instance immediately runs the conversion: ``__post_init__``
    forwards ``encoder``, ``output_dir``, ``pooling_method`` and
    ``dense_metric`` to ``convert_ours_ckpt_to_sentence_transformer``.
    ``model_cache_dir`` is declared but not used by this class.
    """

    encoder: Optional[str] = field(
        default=None,
        metadata=dict(help='Path to the encoder model.'),
    )
    output_dir: Optional[str] = field(
        default=None,
        metadata=dict(help='Path to the output sentence_transformer model.'),
    )
    pooling_method: List[str] = field(
        # A factory is required so instances do not share one default list.
        default_factory=lambda: ["cls"],
        metadata=dict(help='Pooling methods to aggregate token embeddings for a sequence embedding. {cls, mean, dense, decoder}'),
    )
    dense_metric: str = field(
        default="cos",
        metadata=dict(help='What type of metric for dense retrieval? ip, l2, or cos.'),
    )
    model_cache_dir: Optional[str] = field(
        default=None,
        metadata=dict(help='Cache folder for huggingface transformers.'),
    )

    def __post_init__(self):
        # The conversion is a side effect of construction.
        convert_ours_ckpt_to_sentence_transformer(
            src_dir=self.encoder,
            dest_dir=self.output_dir,
            pooling_method=self.pooling_method,
            dense_metric=self.dense_metric,
        )
if __name__ == "__main__":
    # Script entry point: parse the command line into a single Args instance.
    # The conversion itself is triggered from Args.__post_init__, so nothing
    # further needs to be called here.
    parser = HfArgumentParser([Args])
    (args,) = parser.parse_args_into_dataclasses()