support sentence_transformers in LM-Cocktail

This commit is contained in:
shitao 2023-11-29 17:33:23 +08:00
parent 0c9cc0641c
commit 594fcbceff
3 changed files with 18 additions and 0 deletions

2
.gitignore vendored
View File

@ -131,6 +131,8 @@ Untitled.ipynb
try.py
update_model_card.py
model_card.md
pic.py
pic2.py
# Pyre type checker
.pyre/

View File

@ -4,10 +4,23 @@ import numpy as np
from typing import List, Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer, models
from .utils import load_model, get_model_param_list, merge_param, compute_weights
def save_ckpt_for_sentence_transformers(ckpt_dir, pooling_mode: str = 'cls', normalized: bool = True):
    """Re-save the checkpoint in ``ckpt_dir`` as a sentence_transformers model.

    The checkpoint is loaded as a ``models.Transformer`` module, followed by a
    ``models.Pooling`` module and, optionally, a ``models.Normalize`` module.
    The assembled pipeline is then saved back into ``ckpt_dir``.

    Args:
        ckpt_dir: Directory the checkpoint is loaded from; also the save target.
        pooling_mode: Value forwarded to ``models.Pooling`` as ``pooling_mode``.
        normalized: When true, a ``models.Normalize`` module is appended as the
            last stage of the pipeline.
    """
    transformer = models.Transformer(ckpt_dir)
    embedding_dim = transformer.get_word_embedding_dimension()

    # Build the module list once; the normalization stage is the only
    # difference between the two supported configurations.
    pipeline = [transformer, models.Pooling(embedding_dim, pooling_mode=pooling_mode)]
    if normalized:
        pipeline.append(models.Normalize())

    # The model is assembled on CPU and written back to the source directory.
    SentenceTransformer(modules=pipeline, device='cpu').save(ckpt_dir)
def mix_models(model_names_or_paths: List[str],
model_type: str,
@ -45,6 +58,9 @@ def mix_models(model_names_or_paths: List[str],
tokenizer = AutoTokenizer.from_pretrained(model_names_or_paths[0])
tokenizer.save_pretrained(output_path)
if model_type == "encoder":
print(f"Transform the model to the format of 'sentence_transformers' (pooling_method='cls', normalized=True)")
save_ckpt_for_sentence_transformers(ckpt_dir=output_path)
return model

BIN
agnews.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB