Mirror of https://github.com/FlagOpen/FlagEmbedding.git (synced 2025-06-27 02:39:58 +00:00)
support sentence_transformers in LM-Cocktail
This commit is contained in:
parent 0c9cc0641c
commit 594fcbceff
.gitignore (vendored), +2

@@ -131,6 +131,8 @@ Untitled.ipynb
 try.py
 update_model_card.py
 model_card.md
+pic.py
+pic2.py
 
 # Pyre type checker
 .pyre/
@@ -4,10 +4,23 @@ import numpy as np
 from typing import List, Dict, Any
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
+from sentence_transformers import SentenceTransformer, models
 
 from .utils import load_model, get_model_param_list, merge_param, compute_weights
 
 
+def save_ckpt_for_sentence_transformers(ckpt_dir, pooling_mode: str = 'cls', normalized: bool = True):
+    word_embedding_model = models.Transformer(ckpt_dir)
+    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
+                                   pooling_mode=pooling_mode)
+    if normalized:
+        normalized_layer = models.Normalize()
+        model = SentenceTransformer(modules=[word_embedding_model, pooling_model, normalized_layer],
+                                    device='cpu')
+    else:
+        model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cpu')
+    model.save(ckpt_dir)
+
+
 def mix_models(model_names_or_paths: List[str],
                model_type: str,
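
For context, a minimal usage sketch (not part of this diff; the path is a hypothetical example): once save_ckpt_for_sentence_transformers has run on a checkpoint directory, model.save() has written the modules.json plus pooling/normalization configs that sentence_transformers expects, so the directory loads directly:

# Sketch assuming "./mixed_encoder" is a checkpoint directory already converted
# by save_ckpt_for_sentence_transformers above (pooling_mode='cls', normalized=True).
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("./mixed_encoder", device="cpu")
embeddings = model.encode(["what is LM-Cocktail?", "merging language models"])
print(embeddings.shape)  # (2, hidden_size); vectors already L2-normalized by the Normalize module
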
@@ -45,6 +58,9 @@ def mix_models(model_names_or_paths: List[str],
         tokenizer = AutoTokenizer.from_pretrained(model_names_or_paths[0])
         tokenizer.save_pretrained(output_path)
 
+        if model_type == "encoder":
+            print(f"Transform the model to the format of 'sentence_transformers' (pooling_method='cls', normalized=True)")
+            save_ckpt_for_sentence_transformers(ckpt_dir=output_path)
     return model
 
 
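
A hedged end-to-end sketch (not taken from the commit; the model names, weights, and output path are placeholders, and the weights/output_path parameters plus the top-level mix_models import are assumed from the existing LM_Cocktail API): with this change, mixing encoder checkpoints leaves output_path loadable by sentence_transformers as well as transformers.

from LM_Cocktail import mix_models

# Merge two example encoder checkpoints; model_type="encoder" now also triggers
# the sentence_transformers export added in this commit.
mix_models(
    model_names_or_paths=["BAAI/bge-base-en-v1.5", "./my_finetuned_bge"],  # example checkpoints
    model_type="encoder",
    weights=[0.5, 0.5],              # assumed merge-weight parameter
    output_path="./mixed_encoder",   # hypothetical output directory
)
# "./mixed_encoder" can now be loaded via SentenceTransformer("./mixed_encoder").
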
|
BIN
agnews.png
Normal file
BIN
agnews.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 42 KiB |