mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2025-11-17 18:45:57 +00:00
75 lines
2.7 KiB
Python
75 lines
2.7 KiB
Python
import os
|
|
from typing import Optional, List
|
|
from dataclasses import dataclass, field
|
|
from sentence_transformers import models, SentenceTransformer
|
|
from transformers import HfArgumentParser
|
|
|
|
|
|
def convert_ours_ckpt_to_sentence_transformer(src_dir, dest_dir, pooling_method: List[str] = ['cls'], dense_metric: str="cos"):
|
|
assert os.path.exists(src_dir), f"Make sure the encoder path {src_dir} is valid on disk!"
|
|
assert "decoder" not in pooling_method, f"Pooling method 'decode' cannot be saved as sentence_transformers because it uses the decoder stack to produce sentence embedding."
|
|
if dest_dir is None:
|
|
dest_dir = src_dir
|
|
|
|
print(f"loading model from {src_dir} and saving the sentence_transformer model at {dest_dir}...")
|
|
|
|
word_embedding_model = models.Transformer(src_dir)
|
|
modules = [word_embedding_model]
|
|
ndim = word_embedding_model.get_word_embedding_dimension()
|
|
|
|
if "cls" in pooling_method:
|
|
pooling_model = models.Pooling(ndim, pooling_mode="cls")
|
|
pooling_method.remove("cls")
|
|
elif "mean" in pooling_method:
|
|
pooling_model = models.Pooling(ndim, pooling_mode="mean")
|
|
pooling_method.remove("mean")
|
|
else:
|
|
raise NotImplementedError(f"Fail to find cls or mean in pooling_method {pooling_method}!")
|
|
|
|
modules.append(pooling_model)
|
|
|
|
if "dense" in pooling_method:
|
|
modules.append(models.Dense(ndim, ndim, bias=False))
|
|
pooling_method.remove("dense")
|
|
|
|
assert len(pooling_method) == 0, f"Found unused pooling_method {pooling_method}!"
|
|
|
|
if dense_metric == "cos":
|
|
normalize_layer = models.Normalize()
|
|
modules.append(normalize_layer)
|
|
|
|
model = SentenceTransformer(modules=modules, device='cpu')
|
|
model.save(dest_dir)
|
|
|
|
|
|
@dataclass
|
|
class Args:
|
|
encoder: Optional[str] = field(
|
|
default=None,
|
|
metadata={'help': 'Path to the encoder model.'}
|
|
)
|
|
output_dir: Optional[str] = field(
|
|
default=None,
|
|
metadata={'help': 'Path to the output sentence_transformer model.'}
|
|
)
|
|
pooling_method: List[str] = field(
|
|
default_factory=lambda: ["cls"],
|
|
metadata={'help': 'Pooling methods to aggregate token embeddings for a sequence embedding. {cls, mean, dense, decoder}'}
|
|
)
|
|
dense_metric: str = field(
|
|
default="cos",
|
|
metadata={'help': 'What type of metric for dense retrieval? ip, l2, or cos.'}
|
|
)
|
|
model_cache_dir: Optional[str] = field(
|
|
default=None,
|
|
metadata={'help': 'Cache folder for huggingface transformers.'}
|
|
)
|
|
|
|
def __post_init__(self):
|
|
convert_ours_ckpt_to_sentence_transformer(self.encoder, self.output_dir, self.pooling_method, self.dense_metric)
|
|
|
|
if __name__ == "__main__":
|
|
parser = HfArgumentParser([Args])
|
|
args, = parser.parse_args_into_dataclasses()
|
|
|