diff --git a/.gitignore b/.gitignore
index 80a5373..9b4fde6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -131,6 +131,8 @@ Untitled.ipynb
 try.py
 update_model_card.py
 model_card.md
+pic.py
+pic2.py
 
 # Pyre type checker
 .pyre/
diff --git a/LM_Cocktail/LM_Cocktail/cocktail.py b/LM_Cocktail/LM_Cocktail/cocktail.py
index 2c60640..2ca8294 100644
--- a/LM_Cocktail/LM_Cocktail/cocktail.py
+++ b/LM_Cocktail/LM_Cocktail/cocktail.py
@@ -4,10 +4,23 @@ import numpy as np
 from typing import List, Dict, Any
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
+from sentence_transformers import SentenceTransformer, models
 
 from .utils import load_model, get_model_param_list, merge_param, compute_weights
 
 
+def save_ckpt_for_sentence_transformers(ckpt_dir, pooling_mode: str = 'cls', normalized: bool = True):
+    word_embedding_model = models.Transformer(ckpt_dir)
+    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
+                                   pooling_mode=pooling_mode)
+    if normalized:
+        normalized_layer = models.Normalize()
+        model = SentenceTransformer(modules=[word_embedding_model, pooling_model, normalized_layer],
+                                    device='cpu')
+    else:
+        model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cpu')
+    model.save(ckpt_dir)
+
 
 def mix_models(model_names_or_paths: List[str], 
                model_type: str, 
@@ -45,6 +58,9 @@ def mix_models(model_names_or_paths: List[str],
 
         tokenizer = AutoTokenizer.from_pretrained(model_names_or_paths[0])
         tokenizer.save_pretrained(output_path)
+        if model_type == "encoder":
+            print(f"Transform the model to the format of 'sentence_transformers' (pooling_method='cls', normalized=True)")
+            save_ckpt_for_sentence_transformers(ckpt_dir=output_path)
 
     return model
 
diff --git a/agnews.png b/agnews.png
new file mode 100644
index 0000000..88425cb
Binary files /dev/null and b/agnews.png differ
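
Note on usage: with this patch, an "encoder" merge is additionally saved in sentence_transformers format (cls pooling, normalized), so the output directory can be loaded directly with SentenceTransformer. The snippet below is a minimal sketch, not part of the patch: the model names and weights are illustrative, and the weights argument follows LM_Cocktail's documented mix_models signature rather than anything shown in this diff.

    from LM_Cocktail import mix_models
    from sentence_transformers import SentenceTransformer

    # Merge two embedding checkpoints; with model_type="encoder" the output
    # directory is also converted by save_ckpt_for_sentence_transformers.
    mix_models(
        model_names_or_paths=["BAAI/bge-base-en-v1.5", "BAAI/bge-base-en-v1.5"],  # illustrative checkpoints
        model_type="encoder",
        weights=[0.5, 0.5],            # assumed from the library's documented API, not visible in this diff
        output_path="./mixed_encoder",
    )

    # The output directory now contains the sentence_transformers module config,
    # so it loads without any extra conversion step.
    model = SentenceTransformer("./mixed_encoder")
    embeddings = model.encode(["example sentence"])

Because the conversion calls model.save(ckpt_dir) on the same directory that mix_models writes to, no separate export step is needed after merging.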