import logging
from typing import Tuple

from transformers import (
    AutoModel, AutoConfig,
    AutoTokenizer, PreTrainedTokenizer
)

from FlagEmbedding.abc.finetune.embedder import (
    AbsEmbedderRunner, AbsEmbedderModel, EmbedderTrainerCallbackForDataRefresh,
    AbsEmbedderModelArguments, AbsEmbedderTrainingArguments
)

from modeling import BiIREmbedderModel
from trainer import IREmbedderTrainer
from dataset import (
    IREmbedderTrainDataset, IREmbedderCollator,
    IREmbedderSameDatasetTrainDataset, IREmbedderSameDatasetCollator
)

logger = logging.getLogger(__name__)


class IREmbedderRunner(AbsEmbedderRunner):
    """
    Finetune runner for IR embedder models.
    """

    def load_train_dataset(self):
        """Load the training dataset.

        Returns:
            The training dataset instance.
        """
        if self.data_args.same_dataset_within_batch:
            train_dataset = IREmbedderSameDatasetTrainDataset(
                args=self.data_args,
                default_batch_size=self.training_args.per_device_train_batch_size,
                seed=self.training_args.seed,
                tokenizer=self.tokenizer,
                process_index=self.training_args.process_index,
                num_processes=self.training_args.world_size
            )
            # The dataset yields complete batches itself, so the dataloader
            # runs with a batch size of 1 and a single worker.
            self.training_args.per_device_train_batch_size = 1
            self.training_args.dataloader_num_workers = 0  # avoid multi-processing
        else:
            train_dataset = IREmbedderTrainDataset(
                args=self.data_args,
                tokenizer=self.tokenizer
            )
        return train_dataset

    def load_data_collator(self):
        """Load the data collator matching the chosen dataset type.

        Returns:
            The collator instance used to batch and pad training examples.
        """
        if self.data_args.same_dataset_within_batch:
            EmbedCollator = IREmbedderSameDatasetCollator
        else:
            EmbedCollator = IREmbedderCollator

        data_collator = EmbedCollator(
            tokenizer=self.tokenizer,
            query_max_len=self.data_args.query_max_len,
            passage_max_len=self.data_args.passage_max_len,
            sub_batch_size=self.training_args.sub_batch_size,
            pad_to_multiple_of=self.data_args.pad_to_multiple_of,
            padding=True,
            return_tensors="pt"
        )
        return data_collator

    def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderModel]:
        """Load tokenizer and model.

        Returns:
            Tuple[PreTrainedTokenizer, AbsEmbedderModel]: Tokenizer and model instances.
        """
        tokenizer = AutoTokenizer.from_pretrained(
            self.model_args.model_name_or_path,
            cache_dir=self.model_args.cache_dir,
            token=self.model_args.token,
            trust_remote_code=self.model_args.trust_remote_code
        )
        base_model = AutoModel.from_pretrained(
            self.model_args.model_name_or_path,
            cache_dir=self.model_args.cache_dir,
            token=self.model_args.token,
            trust_remote_code=self.model_args.trust_remote_code
        )

        num_labels = 1
        config = AutoConfig.from_pretrained(
            self.model_args.config_name if self.model_args.config_name else self.model_args.model_name_or_path,
            num_labels=num_labels,
            cache_dir=self.model_args.cache_dir,
            token=self.model_args.token,
            trust_remote_code=self.model_args.trust_remote_code,
        )
        logger.info('Config: %s', config)

        model = BiIREmbedderModel(
            base_model,
            tokenizer=tokenizer,
            negatives_cross_device=self.training_args.negatives_cross_device,
            temperature=self.training_args.temperature,
            answer_temperature=self.training_args.answer_temperature,
            sub_batch_size=self.training_args.sub_batch_size,
            kd_loss_type=self.training_args.kd_loss_type,
            sentence_pooling_method=self.training_args.sentence_pooling_method,
            normalize_embeddings=self.training_args.normalize_embeddings,
            normalize_answer=self.training_args.normalize_answer,
            training_type=self.training_args.training_type
        )

        if self.training_args.gradient_checkpointing:
            model.enable_input_require_grads()

        if self.training_args.fix_position_embedding:
            for k, v in model.named_parameters():
                if "position_embeddings" in k:
                    logger.info(f"Freeze the parameters for {k}")
                    v.requires_grad = False
        return tokenizer, model

    def load_trainer(self) -> IREmbedderTrainer:
        """Load the trainer.

        Returns:
            IREmbedderTrainer: Loaded trainer instance.
        """
        trainer = IREmbedderTrainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.train_dataset,
            data_collator=self.data_collator,
            tokenizer=self.tokenizer
        )
        if self.data_args.same_dataset_within_batch:
            trainer.add_callback(EmbedderTrainerCallbackForDataRefresh(self.train_dataset))
        return trainer
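

# Usage sketch (illustrative, not part of the original module): a minimal,
# hedged example of how this runner is typically driven. It assumes the
# argument dataclasses are parsed with transformers.HfArgumentParser, that
# FlagEmbedding.abc.finetune.embedder exposes AbsEmbedderDataArguments, and
# that AbsEmbedderRunner takes (model_args, data_args, training_args) and
# provides a run() method, as in upstream FlagEmbedding; adjust the names
# if this fork differs.
if __name__ == "__main__":
    from transformers import HfArgumentParser
    from FlagEmbedding.abc.finetune.embedder import AbsEmbedderDataArguments

    parser = HfArgumentParser((
        AbsEmbedderModelArguments,
        AbsEmbedderDataArguments,
        AbsEmbedderTrainingArguments
    ))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    runner = IREmbedderRunner(
        model_args=model_args,
        data_args=data_args,
        training_args=training_args
    )
    # Builds tokenizer/model, dataset, collator, and trainer, then trains.
    runner.run()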