Upload code and example of fine-tuning bge-m3

hanhainebula 2024-03-01 18:09:02 +08:00
parent 0ad0dee592
commit 8d50e27d30
11 changed files with 1026 additions and 11 deletions

FlagEmbedding/BGE_M3/__init__.py
View File

@ -1 +1,2 @@
-from .modeling import BGEM3Model, BGEM3ForInference
+from .modeling import BGEM3Model, BGEM3ForInference, EncoderOutput
+from .trainer import BiTrainer

FlagEmbedding/BGE_M3/arguments.py
View File

@ -0,0 +1,93 @@
import os
from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
normlized: bool = field(default=True)
@dataclass
class DataArguments:
knowledge_distillation: bool = field(
default=False, metadata={"help": "Use knowledge distillation when `pos_scores` and `neg_scores` are in features of training data"}
)
train_data: str = field(
default=None, metadata={"help": "One or more paths to training data", "nargs": "+"}
)
cache_path: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the cached data"}
)
train_group_size: int = field(default=8)
query_max_len: int = field(
default=32,
metadata={
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
passage_max_len: int = field(
default=128,
metadata={
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
max_example_num_per_dataset: int = field(
default=None, metadata={"help": "the max number of examples for each dataset"}
)
query_instruction_for_retrieval: str = field(
default=None, metadata={"help": "instruction for query"}
)
passage_instruction_for_retrieval: str = field(
default=None, metadata={"help": "instruction for passage"}
)
same_task_within_batch: bool = field(
default=False, metadata={"help": "All samples in the same batch comes from the same task."}
)
shuffle_ratio: float = field(
default=0.0, metadata={"help": "The ratio of shuffling the text"}
)
def __post_init__(self):
for train_dir in self.train_data:
if not os.path.exists(train_dir):
raise FileNotFoundError(f"cannot find file: {train_dir}, please set a true path")
@dataclass
class RetrieverTrainingArguments(TrainingArguments):
negatives_cross_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
temperature: Optional[float] = field(default=0.02)
fix_position_embedding: bool = field(default=False, metadata={"help": "Freeze the parameters of position embeddings"})
sentence_pooling_method: str = field(default='cls', metadata={"help": "the pooling method, should be cls or mean"})
enable_sub_batch: bool = field(default=True, metadata={"help": "Split a large batch into sub-batches during encoding to save GPU memory"})
unified_finetuning: bool = field(default=False, metadata={"help": "Use unified fine-tuning"})
use_self_distill: bool = field(default=False, metadata={"help": "Use self-distillation when using unified fine-tuning"})
fix_encoder: bool = field(default=False, metadata={"help": "Freeze the parameters of encoder"})
colbert_dim: int = field(default=-1, metadata={"help": "Dim of colbert linear"})
self_distill_start_step: int = field(default=-1, metadata={"help": "The step at which to start self-distillation"})

FlagEmbedding/BGE_M3/data.py
View File

@ -0,0 +1,303 @@
import math
import os.path
import random
from dataclasses import dataclass
from typing import List, Tuple
import torch
import numpy as np
import datasets
from pprint import pprint
from torch.utils.data import Dataset
from transformers import DataCollatorWithPadding
from transformers import PreTrainedTokenizer, BatchEncoding
import torch.distributed as dist
from .arguments import DataArguments
class SameDatasetTrainDataset(Dataset):
"""Dataset to yield a batch of data at one time. All samples in the same batch comes from the same task.
"""
def __init__(self, args: DataArguments, batch_size: int, seed: int, process_index: int=0, num_processes: int=1):
train_datasets = []
each_data_inxs = []
batch_size_inxs = []
pqloss_flag = []
cur_all_num = 0
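# datasets smaller than SMALL_THRESHOLD are merged into one pool; the merged
# pool is kept only if it reaches DROP_THRESHOLD examples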
SMALL_THRESHOLD = 200
DROP_THRESHOLD = 200
context_feat = datasets.Features({
'query': datasets.Value('string'),
'pos': datasets.Sequence(datasets.Value('string')),
'neg': datasets.Sequence(datasets.Value('string'))
})
context_feat_kd = datasets.Features({
'query': datasets.Value('string'),
'pos': datasets.Sequence(datasets.Value('string')),
'neg': datasets.Sequence(datasets.Value('string')),
'pos_scores': datasets.Sequence(datasets.Value('float')),
'neg_scores': datasets.Sequence(datasets.Value('float')),
})
assert isinstance(args.train_data, list) and len(args.train_data) >= 1
if dist.get_rank() == 0:
self.print_batch_size(batch_size=batch_size, train_group_size=args.train_group_size)
for data_dir in args.train_data:
if not os.path.isdir(data_dir):
raise FileNotFoundError(f"{data_dir} is a file, not a directionary")
small_datasets = []
small_batch_size = math.inf
# Add `parallel_` in `data_dir` to indicate that this dataset is parallel corpus
flag = 'parallel_' in data_dir
for file in os.listdir(data_dir):
if not (file.endswith('.json') or file.endswith('.jsonl')):
continue
file_path = os.path.join(data_dir, file)
if dist.get_rank() == 0:
print(f'loading data from {file_path} ...')
try:
temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=args.cache_path, features=context_feat)
except Exception:
temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=args.cache_path, features=context_feat_kd)
if not args.knowledge_distillation:
temp_dataset = temp_dataset.remove_columns(['pos_scores', 'neg_scores'])
if len(temp_dataset) == 0:
continue
elif len(temp_dataset) < SMALL_THRESHOLD:
small_datasets.append(temp_dataset)
small_batch_size = min(small_batch_size, self.get_file_batch_size(file, batch_size, train_group_size=args.train_group_size))
else:
if args.max_example_num_per_dataset is not None and len(temp_dataset) > args.max_example_num_per_dataset:
temp_dataset = temp_dataset.select(
random.sample(list(range(len(temp_dataset))), args.max_example_num_per_dataset))
train_datasets.append(temp_dataset)
each_data_inxs.append(np.arange(len(temp_dataset)) + cur_all_num)
cur_all_num += len(temp_dataset)
batch_size_inxs.append(self.get_file_batch_size(file, batch_size, train_group_size=args.train_group_size))
pqloss_flag.append(flag)
if len(small_datasets) > 0:
small_dataset = datasets.concatenate_datasets(small_datasets)
if len(small_dataset) >= DROP_THRESHOLD:
train_datasets.append(small_dataset)
each_data_inxs.append(np.arange(len(small_dataset)) + cur_all_num)
cur_all_num += len(small_dataset)
batch_size_inxs.append(small_batch_size)
pqloss_flag.append(flag)
self.dataset = datasets.concatenate_datasets(train_datasets)
self.each_data_inxs = each_data_inxs
self.datasets_inxs = np.arange(len(each_data_inxs))
self.batch_size_inxs = batch_size_inxs
self.pqloss_flag = pqloss_flag
self.process_index = process_index
self.num_processes = num_processes
self.args = args
self.shuffle_ratio = args.shuffle_ratio
self.deterministic_generator = np.random.default_rng(seed)
self.step = 0
self.refresh_epoch()
def print_batch_size(self, batch_size: int, train_group_size: int):
length_list = ['0-500', '500-1000', '1000-2000', '2000-3000', '3000-4000', '4000-5000', '5000-6000', '6000-7000', '7000-inf']
batch_size_dict = {
k: self.get_file_batch_size(f"len-{k}.jsonl", batch_size, train_group_size) for k in length_list
}
batch_size_list = [
f'{length}:\t{batch_size_dict[length]}' for length in length_list
]
print("=========================")
print("Batch Size Dict:")
pprint(batch_size_list)
print("=========================")
@staticmethod
def get_file_batch_size(file: str, batch_size: int, train_group_size: int):
if train_group_size == 8:
# 80GB
if 'len-0-500.jsonl' in file:
return 48
elif 'len-500-1000.jsonl' in file:
return 32
elif 'len-1000-2000.jsonl' in file:
return 20
elif 'len-2000-3000.jsonl' in file:
return 18
elif 'len-3000-4000.jsonl' in file:
return 14
elif 'len-4000-5000.jsonl' in file:
return 14
elif 'len-5000-6000.jsonl' in file:
return 12
elif 'len-6000-7000.jsonl' in file:
return 10
elif 'len-7000-inf.jsonl' in file:
return 8
else:
return batch_size
elif train_group_size == 1:
# 80GB
if 'len-0-500.jsonl' in file:
return 700
elif 'len-500-1000.jsonl' in file:
return 570
elif 'len-1000-2000.jsonl' in file:
return 388
elif 'len-2000-3000.jsonl' in file:
return 288
elif 'len-3000-4000.jsonl' in file:
return 224
elif 'len-4000-5000.jsonl' in file:
return 180
elif 'len-5000-6000.jsonl' in file:
return 157
elif 'len-6000-7000.jsonl' in file:
return 128
elif 'len-7000-inf.jsonl' in file:
return 104
else:
return batch_size
else:
return batch_size
def refresh_epoch(self):
print(f'---------------------------*Rank {self.process_index}: refresh data---------------------------')
self.deterministic_generator.shuffle(self.datasets_inxs)
# Dynamically adjust batch size
batch_datas = []
for dataset_inx in self.datasets_inxs:
self.deterministic_generator.shuffle(self.each_data_inxs[dataset_inx])
cur_batch_size = self.batch_size_inxs[dataset_inx]*self.num_processes
flag = self.pqloss_flag[dataset_inx]
for start_index in range(0, len(self.each_data_inxs[dataset_inx]), cur_batch_size):
# drop the tail of the dataset if it is too small to give every process a share
if len(self.each_data_inxs[dataset_inx]) - start_index < 2 * self.num_processes:
break
batch_datas.append((self.each_data_inxs[dataset_inx][start_index:start_index+cur_batch_size], flag))
self.deterministic_generator.shuffle(batch_datas)
self.batch_datas = batch_datas
self.step = 0
def __getitem__(self, _):
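# the sampler's index is ignored: batches were pre-built and shuffled in
# refresh_epoch(), and each rank slices out its own shard below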
batch_indices, pqloss_flag = self.batch_datas[self.step]
cur_batch_size = int(len(batch_indices) / self.num_processes)
batch_indices = batch_indices[self.process_index * cur_batch_size: (self.process_index + 1) * cur_batch_size]
batch_data = self.dataset[batch_indices]
self.step += 1
queries, passages, teacher_scores = self.create_batch_data(batch_raw_data=batch_data)
# print('rank, step, flag, query, passage:', dist.get_rank(), self.step, pqloss_flag, queries, passages)
return queries, passages, teacher_scores, pqloss_flag
def shuffle_text(self, text):
if self.shuffle_ratio > 0 and len(text) > 100 and random.random() < self.shuffle_ratio:
split_text = []
chunk_size = len(text)//3 + 1
for i in range(0, len(text), chunk_size):
split_text.append(text[i:i+chunk_size])
random.shuffle(split_text)
return " ".join(split_text)
else:
return text
def create_batch_data(self, batch_raw_data):
queries, passages = [], []
teacher_scores = []
for i in range(len(batch_raw_data['query'])):
queries.append(batch_raw_data['query'][i])
pos_inx = random.choice(list(range(len(batch_raw_data['pos'][i]))))
passages.append(self.shuffle_text(batch_raw_data['pos'][i][pos_inx]))
if 'pos_scores' in batch_raw_data and batch_raw_data['pos_scores'][i] is not None:
teacher_scores.append(batch_raw_data['pos_scores'][i][pos_inx])
neg_inx_set = list(range(len(batch_raw_data['neg'][i])))
if len(batch_raw_data['neg'][i]) < self.args.train_group_size - 1:
num = math.ceil((self.args.train_group_size - 1) / len(batch_raw_data['neg'][i]))
neg_inxs = random.sample(neg_inx_set * num, self.args.train_group_size - 1)
else:
neg_inxs = random.sample(neg_inx_set, self.args.train_group_size - 1)
if 'neg_scores' in batch_raw_data and batch_raw_data['neg_scores'][i] is not None:
neg_scores = [(x, batch_raw_data['neg_scores'][i][x]) for x in neg_inxs]
neg_scores = sorted(neg_scores, key=lambda x:x[1], reverse=True)
neg_inxs = [x[0] for x in neg_scores]
teacher_scores.extend([x[1] for x in neg_scores])
negs = [batch_raw_data['neg'][i][x] for x in neg_inxs]
passages.extend(negs)
if len(teacher_scores) > 0 and len(passages) > 0:
assert len(teacher_scores) == len(passages)
if self.args.query_instruction_for_retrieval is not None:
queries = [self.args.query_instruction_for_retrieval+q for q in queries]
if self.args.passage_instruction_for_retrieval is not None:
passages = [self.args.passage_instruction_for_retrieval+p for p in passages]
if len(teacher_scores) == 0:
teacher_scores = None
return queries, passages, teacher_scores
def __len__(self):
return len(self.batch_datas) * self.num_processes
@dataclass
class EmbedCollator(DataCollatorWithPadding):
"""
Wrapper that converts a List[Tuple[encode_qry, encode_psg]] into List[qry] and List[psg],
and passes each batch separately to the underlying collator.
Abstracts the data details away from the model.
"""
query_max_len: int = 32
passage_max_len: int = 128
def __call__(self, features):
query = [f[0] for f in features]
passage = [f[1] for f in features]
teacher_scores = None
if len(features[0]) > 2:
teacher_scores = [f[2] for f in features]
if teacher_scores[0] is None:
teacher_scores = None
else:
teacher_scores = torch.FloatTensor(teacher_scores)
flag = None
if len(features[0]) == 4:
flag = [f[3] for f in features][0]
if isinstance(query[0], list):
query = sum(query, [])
if isinstance(passage[0], list):
passage = sum(passage, [])
q_collated = self.tokenizer(
query,
# padding='max_length', # used for adjusting the batch size in `get_file_batch_size()`
padding=True,
truncation=True,
max_length=self.query_max_len,
return_tensors="pt",
)
d_collated = self.tokenizer(
passage,
# padding='max_length', # used for adjusting the batch size in `get_file_batch_size()`
padding=True,
truncation=True,
max_length=self.passage_max_len,
return_tensors="pt",
)
if teacher_scores is not None:
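# one row per query: the positive's score followed by its negatives' scores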
teacher_scores = teacher_scores.reshape((len(q_collated['input_ids']), -1))
return {"query": q_collated, "passage": d_collated, "teacher_scores": teacher_scores, "bi_directions": flag}

FlagEmbedding/BGE_M3/modeling.py
View File

@ -34,7 +34,7 @@ class BGEM3Model(nn.Module):
unified_finetuning: bool = True,
use_self_distill: bool = False,
colbert_dim: int = -1,
-ensemble_distill_start_step: int = -1,
+self_distill_start_step: int = -1,
):
super().__init__()
self.load_model(model_name, colbert_dim=colbert_dim)
@ -51,7 +51,7 @@ class BGEM3Model(nn.Module):
self.enable_sub_batch = enable_sub_batch
self.temperature = temperature
self.use_self_distill = use_self_distill
-self.ensemble_distill_start_step = ensemble_distill_start_step
+self.self_distill_start_step = self_distill_start_step
self.step = 0
if not normlized:
@ -260,9 +260,9 @@ class BGEM3Model(nn.Module):
cross_q_dense_vecs = self._dist_gather_tensor(q_dense_vecs)
cross_p_dense_vecs = self._dist_gather_tensor(p_dense_vecs)
-idxs = torch.arange(cross_q_dense_vecs.size(0), device=cross_q_dense_vecs.device, dtype=torch.long)
-cross_targets = idxs * (cross_p_dense_vecs.size(0) // cross_q_dense_vecs.size(0))
+cross_idxs = torch.arange(cross_q_dense_vecs.size(0), device=cross_q_dense_vecs.device, dtype=torch.long)
+cross_targets = cross_idxs * (cross_p_dense_vecs.size(0) // cross_q_dense_vecs.size(0))
cross_dense_scores = self.dense_score(cross_q_dense_vecs, cross_p_dense_vecs)
loss = self.compute_loss(cross_dense_scores, cross_targets)
@ -281,7 +281,7 @@ class BGEM3Model(nn.Module):
ensemble_loss = self.compute_loss(dense_scores + 0.3 * sparse_scores + colbert_scores, targets)
loss = (loss + ensemble_loss + 0.1 * sparse_loss + colbert_loss) / 4
-if self.use_self_distill and self.step > self.ensemble_distill_start_step and self.unified_finetuning:
+if self.use_self_distill and self.step > self.self_distill_start_step and self.unified_finetuning:
ensemble_scores = dense_scores + 0.3 * sparse_scores + colbert_scores
teacher_targets = torch.softmax(ensemble_scores.detach(), dim=-1)
ensemble_distill_dense_loss = - torch.mean(
@ -290,8 +290,7 @@ class BGEM3Model(nn.Module):
torch.sum(torch.log_softmax(sparse_scores, dim=-1) * teacher_targets, dim=-1))
ensemble_distill_colbert_loss = - torch.mean(
torch.sum(torch.log_softmax(colbert_scores, dim=-1) * teacher_targets, dim=-1))
-loss += (
-ensemble_distill_dense_loss + 0.1 * ensemble_distill_sparse_loss + ensemble_distill_colbert_loss) / 3
+loss += (ensemble_distill_dense_loss + 0.1 * ensemble_distill_sparse_loss + ensemble_distill_colbert_loss) / 3
loss = loss / 2
self.step += 1
else:
@ -367,6 +366,3 @@ class BGEM3ForInference(BGEM3Model):
output['colbert_vecs'] = torch.nn.functional.normalize(output['colbert_vecs'], dim=-1)
return output

FlagEmbedding/BGE_M3/run.py
View File

@ -0,0 +1,155 @@
import logging
import os
from pathlib import Path
import torch.distributed as dist
from transformers import AutoConfig, AutoTokenizer
from transformers import (
HfArgumentParser,
set_seed,
)
from transformers import (
TrainerCallback,
TrainingArguments,
TrainerState,
TrainerControl
)
from .arguments import ModelArguments, DataArguments, \
RetrieverTrainingArguments as TrainingArguments
from .data import SameDatasetTrainDataset, EmbedCollator
from .modeling import BGEM3Model
from .trainer import BiTrainer
logger = logging.getLogger(__name__)
class TrainerCallbackForDataRefresh(TrainerCallback):
def __init__(self, train_dataset):
self.train_dataset = train_dataset
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the end of an epoch.
"""
self.train_dataset.refresh_epoch()
def main():
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: ModelArguments
data_args: DataArguments
training_args: TrainingArguments
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)
logger.info("Model parameters %s", model_args)
logger.info("Data parameters %s", data_args)
# Set seed
set_seed(training_args.seed)
num_labels = 1
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=False,
)
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
cache_dir=model_args.cache_dir,
)
logger.info('Config: %s', config)
model = BGEM3Model(model_name=model_args.model_name_or_path,
normlized=model_args.normlized,
sentence_pooling_method=training_args.sentence_pooling_method,
negatives_cross_device=training_args.negatives_cross_device,
temperature=training_args.temperature,
enable_sub_batch=training_args.enable_sub_batch,
unified_finetuning=training_args.unified_finetuning,
use_self_distill=training_args.use_self_distill,
colbert_dim=training_args.colbert_dim,
self_distill_start_step=training_args.self_distill_start_step)
if training_args.fix_position_embedding:
for k, v in model.named_parameters():
if "position_embeddings" in k:
logging.info(f"Freeze the parameters for {k}")
v.requires_grad = False
if training_args.fix_encoder:
for k, v in model.named_parameters():
if "colbert_linear" in k or 'sparse_linear' in k:
logging.info(f"train the parameters for {k}")
else:
v.requires_grad = False
# print(f"===========================Rank {dist.get_rank()}: start loading data===========================")
if data_args.same_task_within_batch:
train_dataset = SameDatasetTrainDataset(args=data_args,
batch_size=training_args.per_device_train_batch_size,
seed=training_args.seed,
num_processes=training_args.world_size,
process_index=training_args.process_index)
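# SameDatasetTrainDataset already yields complete batches, so the dataloader
# batch size is fixed to 1 and dataloader workers are disabled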
training_args.per_device_train_batch_size = 1
training_args.dataloader_num_workers = 0 # avoid multi-processes
else:
raise NotImplementedError("Not support `same_task_within_batch=False`")
data_collator = EmbedCollator(
tokenizer,
query_max_len=data_args.query_max_len,
passage_max_len=data_args.passage_max_len
)
trainer = BiTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=data_collator,
tokenizer=tokenizer
)
if data_args.same_task_within_batch:
trainer.add_callback(TrainerCallbackForDataRefresh(train_dataset))
Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)
# Training
# print(f"===========================Rank {dist.get_rank()}: start training===========================")
trainer.train()
trainer.save_model()
# For convenience, we also re-save the tokenizer to the same directory,
# so that you can share your model easily on huggingface.co/models =)
if trainer.is_world_process_zero():
tokenizer.save_pretrained(training_args.output_dir)
if __name__ == "__main__":
main()

FlagEmbedding/BGE_M3/split_data_by_length.py
View File

@ -0,0 +1,199 @@
"""
python split_data_by_length.py \
--input_path train_data \
--output_dir train_data_split \
--cache_dir .cache \
--log_name .split_log \
--length_list 0 500 1000 2000 3000 4000 5000 6000 7000 \
--model_name_or_path BAAI/bge-m3 \
--num_proc 16 \
--overwrite False
"""
import os
import json
import math
import time
import argparse
from tqdm import tqdm
from pprint import pprint
from transformers import AutoTokenizer
from datasets import load_dataset, Features, Value, Sequence
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--input_path', type=str, required=True, help='the path of the input data')
parser.add_argument('--output_dir', type=str, required=True, help='the directory for the output data')
parser.add_argument('--cache_dir', type=str, default=None, help='the cache dir')
parser.add_argument('--log_name', type=str, default='.split_log', help='the name of log file, default: `.split_log`, which will be saved to `output_dir`')
parser.add_argument('--length_list', type=int, default=[0, 500, 1000, 2000, 3000, 4000, 5000, 6000, 7000], nargs='+', help='the list of length boundaries used to split the data')
parser.add_argument('--model_name_or_path', type=str, default='BAAI/bge-m3', help='the model name or path of the tokenizer')
parser.add_argument('--num_proc', type=int, default=16, help='the number of processes, default: 16')
parser.add_argument('--overwrite', action='store_true', default=False, help='whether to overwrite the output file, default: False')
args = parser.parse_args()
return args
class SplitByLengthHandler:
def __init__(self,
model_name_or_path: str,
cache_dir: str=None,
num_proc: int=16,
length_list: list=[0, 500, 1000, 2000, 3000, 4000, 5000, 6000, 7000],
overwrite: bool=False):
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.cache_dir = cache_dir
self.num_proc = num_proc
self.length_ranges_list = self._get_length_ranges_list(length_list)
self.overwrite = overwrite
pprint(self.length_ranges_list)
def _map_func(examples, indices):
results = {}
results['idx'] = []
results['max_length'] = []
for i in range(len(examples['query'])):
# record the example's global index so that `dataset.select()` below picks the right rows
results['idx'].append(indices[i])
query = examples['query'][i]
pos, neg = examples['pos'][i], examples['neg'][i]
all_texts = [query] + pos + neg
max_len = 0
for x in all_texts:
tokenized_x = self.tokenizer(x)['input_ids']
if len(tokenized_x) > max_len:
max_len = len(tokenized_x)
results['max_length'].append(max_len)
return results
self._map_func = _map_func
@staticmethod
def _get_length_ranges_list(length_list: list):
length_ranges_list = []
length_list = sorted(length_list)
for i in range(len(length_list)):
length_l = length_list[i]
if i == len(length_list) - 1:
length_r = math.inf
else:
length_r = length_list[i + 1]
assert 0 <= length_l < length_r
length_ranges_list.append((length_l, length_r))
return length_ranges_list
def _process_dir(self, dir_path: str, output_dir: str):
assert os.path.isdir(dir_path)
log_info_list = []
for file in tqdm(os.listdir(dir_path), desc=f'processing {dir_path}'):
file_path = os.path.join(dir_path, file)
if not file_path.endswith('.jsonl'):
print(f"skip {file_path} ...")
continue
output_path = os.path.join(output_dir, '.'.join(file.split('.')[:-1]))
log_info = self._process_file(file_path, output_path)
log_info_list.append(log_info)
return log_info_list
def _process_file(self, file_path: str, output_path: str):
assert not os.path.isdir(file_path)
start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
features = Features({
'query': Value('string'),
'pos': Sequence(Value('string')),
'neg': Sequence(Value('string'))
})
kd_features = Features({
'query': Value('string'),
'pos': Sequence(Value('string')),
'neg': Sequence(Value('string')),
'pos_scores': Sequence(Value('float')),
'neg_scores': Sequence(Value('float'))
})
try:
dataset = load_dataset('json', data_files=file_path, cache_dir=self.cache_dir, features=features)['train']
except Exception:
dataset = load_dataset('json', data_files=file_path, cache_dir=self.cache_dir, features=kd_features)['train']
mapped_dataset = dataset.map(self._map_func, batched=True, with_indices=True, num_proc=self.num_proc)
split_info_dict = {}
for length_l, length_r in self.length_ranges_list:
save_path = output_path + f'_len-{length_l}-{length_r}.jsonl'
if os.path.exists(save_path) and not self.overwrite:
print(f'{save_path} exists, skip')
continue
idxs = mapped_dataset.filter(lambda x: length_l <= x['max_length'] < length_r, num_proc=self.num_proc)
split_dataset = dataset.select(idxs['idx'])
split_info_dict[f'len-{length_l}-{length_r}'] = len(split_dataset)
if len(split_dataset) > 0:
split_dataset.to_json(save_path, force_ascii=False)
end_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
size = len(dataset)
avg_length = sum(mapped_dataset['max_length']) / size
log_info = {
'file_name': os.path.basename(file_path),
'size': size,
'avg_length': avg_length,
'file_path': file_path,
'start_time': start_time,
'end_time': end_time,
'split_info': split_info_dict
}
return log_info
def run(self, input_path: str, output_dir: str, log_name: str=None):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
if log_name is None:
log_path = os.path.join(output_dir, '.split_log')
else:
log_path = os.path.join(output_dir, log_name)
log_info_list = []
if os.path.isdir(input_path):
log_info_list = self._process_dir(input_path, output_dir)
else:
file_name = os.path.basename(input_path)
output_path = os.path.join(output_dir, '.'.join(file_name.split('.')[:-1]))
log_info = self._process_file(input_path, output_path)
log_info_list.append(log_info)
with open(log_path, 'a', encoding='utf-8') as f:
for log_info in log_info_list:
json.dump(log_info, f, ensure_ascii=False)
f.write('\n')
if __name__ == '__main__':
args = get_args()
input_path = args.input_path
output_dir = args.output_dir
log_name = args.log_name
handler = SplitByLengthHandler(
model_name_or_path=args.model_name_or_path,
cache_dir=args.cache_dir,
num_proc=args.num_proc,
length_list=args.length_list if isinstance(args.length_list, list) else [args.length_list],
overwrite=args.overwrite
)
handler.run(
input_path=input_path,
output_dir=output_dir,
log_name=log_name
)
print('\nDONE!')

FlagEmbedding/BGE_M3/trainer.py
View File

@ -0,0 +1,51 @@
from sentence_transformers import SentenceTransformer, models
from transformers.trainer import *
def save_ckpt_for_sentence_transformers(ckpt_dir, pooling_mode: str = 'cls', normlized: bool=True):
word_embedding_model = models.Transformer(ckpt_dir)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=pooling_mode)
if normlized:
normlize_layer = models.Normalize()
model = SentenceTransformer(modules=[word_embedding_model, pooling_model, normlize_layer], device='cpu')
else:
model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cpu')
model.save(ckpt_dir)
class BiTrainer(Trainer):
def _save(self, output_dir: Optional[str] = None, state_dict=None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info("Saving model checkpoint to %s", output_dir)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not hasattr(self.model, 'save'):
raise NotImplementedError(
f'MODEL {self.model.__class__.__name__} '
f'does not support save interface')
else:
self.model.save(output_dir)
if self.tokenizer is not None and self.is_world_process_zero():
self.tokenizer.save_pretrained(output_dir)
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
# save the checkpoint for sentence-transformers library
if self.is_world_process_zero():
save_ckpt_for_sentence_transformers(output_dir,
pooling_mode=self.args.sentence_pooling_method,
normlized=self.args.normlized)
def compute_loss(self, model, inputs, return_outputs=False):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
outputs = model(**inputs)
loss = outputs.loss
return (loss, outputs) if return_outputs else loss

examples/unified_finetune/README.md
View File

@ -0,0 +1,107 @@
# Unified Finetune
In this example, we show how to perform unified fine-tuning based on [`BAAI/bge-m3`](https://huggingface.co/BAAI/bge-m3) with your data.
## 1. Installation
- with pip
```
pip install -U FlagEmbedding
```
- from source
```
git clone https://github.com/FlagOpen/FlagEmbedding.git
cd FlagEmbedding
pip install .
```
For development, install as editable:
```
pip install -e .
```
## 2. Data format
Training data should be a jsonl file, where each line is a dict like this:
```
{"query": str, "pos": List[str], "neg":List[str]}
```
`query` is the query, `pos` is a list of positive texts, and `neg` is a list of negative texts.
If you want to use knowledge distillation, each line of your jsonl file should be like this:
```
{"query": str, "pos": List[str], "neg":List[str], "pos_scores": List[float], "neg_scores": List[float]}
```
`pos_scores` is a list of positive scores, where `pos_scores[i]` is the score between `query` and `pos[i]` from the teacher model. `neg_scores` is a list of negative scores, where `neg_scores[i]` is the score between `query` and `neg[i]` from the teacher model.
See [toy_train_data](./toy_train_data) for an example of training data.
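For reference, here is a minimal sketch of producing such lines with plain `json.dumps` (the texts, scores, and file names below are made up):
```python
import json

# A plain training example: a query with positive and negative passages.
example = {
    "query": "A man is playing guitar on stage.",
    "pos": ["Someone performs music in front of an audience."],
    "neg": ["A chef is cooking in a kitchen.", "The library was closed on Sunday."],
}

# A knowledge-distillation example additionally carries teacher scores,
# aligned index-by-index with `pos` and `neg`.
kd_example = dict(example, pos_scores=[0.93], neg_scores=[0.12, 0.05])

# Keep plain and distillation examples in separate files: the training code
# infers one feature schema per file.
with open("plain.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(example, ensure_ascii=False) + "\n")
with open("kd.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(kd_example, ensure_ascii=False) + "\n")
```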
### Use efficient batching strategy [Optional]
If you want to use the **efficient batching strategy** (for more details, please refer to [our paper](https://arxiv.org/pdf/2402.03216.pdf)), you should use this [script](../../BGE_M3/split_data_by_length.py) to split your data into parts by sequence length before training. For example:
```bash
python split_data_by_length.py \
--input_path train_data \
--output_dir train_data_split \
--cache_dir .cache \
--log_name .split_log \
--length_list 0 500 1000 2000 3000 4000 5000 6000 7000 \
--model_name_or_path BAAI/bge-m3 \
--num_proc 16 \
--overwrite False
```
`input_path` is the path to a jsonl file or to a directory containing jsonl files. `output_dir` is the directory where the split files (`*_len-0-500.jsonl`, `*_len-500-1000.jsonl`, etc.) are saved, and `output_dir/log_name` is the log of the split. `length_list` is the list of sequence-length boundaries. `model_name_or_path` is the tokenizer used to measure lengths. `num_proc` is the number of processes to use, and `overwrite` controls whether existing output files are overwritten.
For example, if there are two jsonl files `train_data1.jsonl` and `train_data2.jsonl` in `train_data` directory, then after running the above script, there will be some split files in `train_data_split` like this:
```
train_data_split
├── train_data1_len-0-500.jsonl
├── train_data1_len-500-1000.jsonl
├── train_data1_len-1000-2000.jsonl
├── train_data1_len-2000-3000.jsonl
├── train_data1_len-3000-4000.jsonl
├── train_data1_len-4000-5000.jsonl
├── train_data1_len-5000-6000.jsonl
├── train_data1_len-6000-7000.jsonl
├── train_data1_len-7000-inf.jsonl
├── train_data2_len-0-500.jsonl
├── train_data2_len-500-1000.jsonl
├── train_data2_len-1000-2000.jsonl
├── train_data2_len-2000-3000.jsonl
├── train_data2_len-3000-4000.jsonl
├── train_data2_len-4000-5000.jsonl
├── train_data2_len-5000-6000.jsonl
├── train_data2_len-6000-7000.jsonl
└── train_data2_len-7000-inf.jsonl
```
Note that if there's no data in a specific range, the corresponding file will not be created.
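For reference, a minimal sketch of the bucketing rule the script applies (`bucket_of` is a hypothetical helper, not part of the repo; it assumes the `BAAI/bge-m3` tokenizer): an example lands in the length range that contains the maximum tokenized length over its query, positives, and negatives.
```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
boundaries = [0, 500, 1000, 2000, 3000, 4000, 5000, 6000, 7000]

def bucket_of(example: dict) -> str:
    # the bucket is decided by the longest tokenized text in the example
    texts = [example["query"]] + example["pos"] + example["neg"]
    max_len = max(len(tokenizer(t)["input_ids"]) for t in texts)
    for left, right in zip(boundaries, boundaries[1:] + [float("inf")]):
        if left <= max_len < right:
            return f"len-{left}-{right}"

print(bucket_of({"query": "a dog runs", "pos": ["a dog is running"], "neg": ["a cat sleeps"]}))
# -> len-0-500
```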
## 3. Train
> **Note**: If you only want to fine-tune the dense embedding of `BAAI/bge-m3`, you can refer to [here](../finetune/README.md).
If you want to perform unified fine-tuning based on `BAAI/bge-m3`, please refer to [this script](./unified_finetune_bge-m3_exmaple.sh). In this script, we use `deepspeed` to perform distributed training. Learn more about `deepspeed` at https://www.deepspeed.ai/getting-started/. Note that there are some important parameters to be modified in this script:
- `HOST_FILE_CONTENT`: Machines and GPUs for training. If you want to use multiple machines for training, please refer to https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node (note that you should configure `pdsh` and `ssh` properly).
- `DS_CONFIG_FILE`: Path of deepspeed config file. [Here](../finetune/ds_config.json) is an example of `ds_config.json`.
- `DATA_PATH`: One or more paths of training data. **Each path must be a directory containing one or more jsonl files**.
- `DEFAULT_BATCH_SIZE`: Default batch size for training. If you use efficient batching strategy, which means you have split your data to different parts by sequence length, then the batch size for each part will be decided by the `get_file_batch_size()` function in [`BGE_M3/data.py`](../../FlagEmbedding/BGE_M3/data.py). Before starting training, you should set the corresponding batch size for each part in this function according to the GPU memory of your machines. `DEFAULT_BATCH_SIZE` will be used for the part whose sequence length is not in the `get_file_batch_size()` function.
- `EPOCHS`: Number of training epochs.
- `LEARNING_RATE`: The initial learning rate.
- `SAVE_PATH`: Path for saving the fine-tuned model.
You should set these parameters appropriately.
For more detailed argument settings, please refer to [`BGE_M3/arguments.py`](../../FlagEmbedding/BGE_M3/arguments.py).

View File

@ -0,0 +1,10 @@
{"query": "Five women walk along a beach wearing flip-flops.", "pos": ["Some women with flip-flops on, are walking along the beach"], "neg": ["The 4 women are sitting on the beach.", "There was a reform in 1996.", "She's not going to court to clear her record.", "The man is talking about hawaii.", "A woman is standing outside.", "The battle was over. ", "A group of people plays volleyball."]}
{"query": "A woman standing on a high cliff on one leg looking over a river.", "pos": ["A woman is standing on a cliff."], "neg": ["A woman sits on a chair.", "George Bush told the Republicans there was no way he would let them even consider this foolish idea, against his top advisors advice.", "The family was falling apart.", "no one showed up to the meeting", "A boy is sitting outside playing in the sand.", "Ended as soon as I received the wire.", "A child is reading in her bedroom."]}
{"query": "Two woman are playing instruments; one a clarinet, the other a violin.", "pos": ["Some people are playing a tune."], "neg": ["Two women are playing a guitar and drums.", "A man is skiing down a mountain.", "The fatal dose was not taken when the murderer thought it would be.", "Person on bike", "The girl is standing, leaning against the archway.", "A group of women watch soap operas.", "No matter how old people get they never forget. "]}
{"query": "A girl with a blue tank top sitting watching three dogs.", "pos": ["A girl is wearing blue."], "neg": ["A girl is with three cats.", "The people are watching a funeral procession.", "The child is wearing black.", "Financing is an issue for us in public schools.", "Kids at a pool.", "It is calming to be assaulted.", "I face a serious problem at eighteen years old. "]}
{"query": "A yellow dog running along a forest path.", "pos": ["a dog is running"], "neg": ["a cat is running", "Steele did not keep her original story.", "The rule discourages people to pay their child support.", "A man in a vest sits in a car.", "Person in black clothing, with white bandanna and sunglasses waits at a bus stop.", "Neither the Globe or Mail had comments on the current state of Canada's road system. ", "The Spring Creek facility is old and outdated."]}
{"query": "It sets out essential activities in each phase along with critical factors related to those activities.", "pos": ["Critical factors for essential activities are set out."], "neg": ["It lays out critical activities but makes no provision for critical factors related to those activities.", "People are assembled in protest.", "The state would prefer for you to do that.", "A girl sits beside a boy.", "Two males are performing.", "Nobody is jumping", "Conrad was being plotted against, to be hit on the head."]}
{"query": "A man giving a speech in a restaurant.", "pos": ["A person gives a speech."], "neg": ["The man sits at the table and eats food.", "This is definitely not an endorsement.", "They sold their home because they were retiring and not because of the loan.", "The seal of Missouri is perfect.", "Someone is raising their hand.", "An athlete is competing in the 1500 meter swimming competition.", "Two men watching a magic show."]}
{"query": "Indians having a gathering with coats and food and drinks.", "pos": ["A group of Indians are having a gathering with food and drinks"], "neg": ["A group of Indians are having a funeral", "It is only staged on Winter afternoons in Palma's large bullring.", "Right information can empower the legal service practices and the justice system. ", "Meanwhile, the mainland was empty of population.", "Two children is sleeping.", "a fisherman is trying to catch a monkey", "the people are in a train"]}
{"query": "A woman with violet hair rides her bicycle outside.", "pos": ["A woman is riding her bike."], "neg": ["A woman is jogging in the park.", "The street was lined with white-painted houses.", "A group watches a movie inside.", "man at picnics cut steak", "Several chefs are sitting down and talking about food.", "The Commission notes that no significant alternatives were considered.", "We ran out of firewood and had to use pine needles for the fire."]}
{"query": "A man pulls two women down a city street in a rickshaw.", "pos": ["A man is in a city."], "neg": ["A man is a pilot of an airplane.", "It is boring and mundane.", "The morning sunlight was shining brightly and it was warm. ", "Two people jumped off the dock.", "People watching a spaceship launch.", "Mother Teresa is an easy choice.", "It's worth being able to go at a pace you prefer."]}

View File

@ -0,0 +1,10 @@
{"query": "Five women walk along a beach wearing flip-flops.", "pos": ["Some women with flip-flops on, are walking along the beach"], "neg": ["The 4 women are sitting on the beach.", "There was a reform in 1996.", "She's not going to court to clear her record.", "The man is talking about hawaii.", "A woman is standing outside.", "The battle was over. ", "A group of people plays volleyball."]}
{"query": "A woman standing on a high cliff on one leg looking over a river.", "pos": ["A woman is standing on a cliff."], "neg": ["A woman sits on a chair.", "George Bush told the Republicans there was no way he would let them even consider this foolish idea, against his top advisors advice.", "The family was falling apart.", "no one showed up to the meeting", "A boy is sitting outside playing in the sand.", "Ended as soon as I received the wire.", "A child is reading in her bedroom."]}
{"query": "Two woman are playing instruments; one a clarinet, the other a violin.", "pos": ["Some people are playing a tune."], "neg": ["Two women are playing a guitar and drums.", "A man is skiing down a mountain.", "The fatal dose was not taken when the murderer thought it would be.", "Person on bike", "The girl is standing, leaning against the archway.", "A group of women watch soap operas.", "No matter how old people get they never forget. "]}
{"query": "A girl with a blue tank top sitting watching three dogs.", "pos": ["A girl is wearing blue."], "neg": ["A girl is with three cats.", "The people are watching a funeral procession.", "The child is wearing black.", "Financing is an issue for us in public schools.", "Kids at a pool.", "It is calming to be assaulted.", "I face a serious problem at eighteen years old. "]}
{"query": "A yellow dog running along a forest path.", "pos": ["a dog is running"], "neg": ["a cat is running", "Steele did not keep her original story.", "The rule discourages people to pay their child support.", "A man in a vest sits in a car.", "Person in black clothing, with white bandanna and sunglasses waits at a bus stop.", "Neither the Globe or Mail had comments on the current state of Canada's road system. ", "The Spring Creek facility is old and outdated."]}
{"query": "It sets out essential activities in each phase along with critical factors related to those activities.", "pos": ["Critical factors for essential activities are set out."], "neg": ["It lays out critical activities but makes no provision for critical factors related to those activities.", "People are assembled in protest.", "The state would prefer for you to do that.", "A girl sits beside a boy.", "Two males are performing.", "Nobody is jumping", "Conrad was being plotted against, to be hit on the head."]}
{"query": "A man giving a speech in a restaurant.", "pos": ["A person gives a speech."], "neg": ["The man sits at the table and eats food.", "This is definitely not an endorsement.", "They sold their home because they were retiring and not because of the loan.", "The seal of Missouri is perfect.", "Someone is raising their hand.", "An athlete is competing in the 1500 meter swimming competition.", "Two men watching a magic show."]}
{"query": "Indians having a gathering with coats and food and drinks.", "pos": ["A group of Indians are having a gathering with food and drinks"], "neg": ["A group of Indians are having a funeral", "It is only staged on Winter afternoons in Palma's large bullring.", "Right information can empower the legal service practices and the justice system. ", "Meanwhile, the mainland was empty of population.", "Two children is sleeping.", "a fisherman is trying to catch a monkey", "the people are in a train"]}
{"query": "A woman with violet hair rides her bicycle outside.", "pos": ["A woman is riding her bike."], "neg": ["A woman is jogging in the park.", "The street was lined with white-painted houses.", "A group watches a movie inside.", "man at picnics cut steak", "Several chefs are sitting down and talking about food.", "The Commission notes that no significant alternatives were considered.", "We ran out of firewood and had to use pine needles for the fire."]}
{"query": "A man pulls two women down a city street in a rickshaw.", "pos": ["A man is in a city."], "neg": ["A man is a pilot of an airplane.", "It is boring and mundane.", "The morning sunlight was shining brightly and it was warm. ", "Two people jumped off the dock.", "People watching a spaceship launch.", "Mother Teresa is an easy choice.", "It's worth being able to go at a pace you prefer."]}

unified_finetune_bge-m3_exmaple.sh
View File

@ -0,0 +1,90 @@
#!/bin/bash
# Set root path
ROOT=/home
# Set training machines
# For more details, refer to https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node
HOST_FILE_CONTENT="\
localhost slots=8\n\
"
HOST_FILE=hostfile
printf "$HOST_FILE_CONTENT" > $HOST_FILE
DISTRIBUTED_ARGS="--hostfile $HOST_FILE"
export LAUNCHER="deepspeed \
$DISTRIBUTED_ARGS \
"
# Set cache directory
CACHE_PATH=$ROOT/datasets/.cache
# Set path of deepspeed config file
# For more details, refer to https://huggingface.co/docs/transformers/main_classes/deepspeed#zero
DS_CONFIG_FILE=$ROOT/train/ds_config.json
# Set group size of training
GROUP_SIZE=8
# Set paths of training data. Every path **must be a directory path**.
DATA_PATH="
$ROOT/datasets/toy_train_data \
"
# Set default batch size for training.
# If you want to use the efficient batching strategy, you should first use the script `split_data_by_length.py` to split your data by sequence length. The batch size for each batch will then depend on its sequence-length range, such as len-0-500: 48, len-500-1000: 32, etc., as defined in `get_file_batch_size()` in `BGE_M3/data.py`.
DEFAULT_BATCH_SIZE=8
# Set number of training epochs.
# For more details, refer to https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
EPOCHS=1
MAX_STEPS=-1
# Set base model and save path
BASE_MODEL=BAAI/bge-m3
SAVE_PATH=$ROOT/models/bge-m3_finetuned
mkdir -p $SAVE_PATH
# Set learning rate
LEARNING_RATE=1e-6
full_options="
--knowledge_distillation True \
--output_dir $SAVE_PATH \
--model_name_or_path $BASE_MODEL \
--normlized True \
--temperature 0.02 \
--do_train \
--train_data $DATA_PATH \
--cache_path $CACHE_PATH \
--per_device_train_batch_size $DEFAULT_BATCH_SIZE \
--query_max_len 512 \
--passage_max_len 8192 \
--fp16 \
--save_steps 1500 \
--train_group_size $GROUP_SIZE \
--learning_rate $LEARNING_RATE \
--num_train_epochs $EPOCHS \
--max_steps $MAX_STEPS \
--negatives_cross_device False \
--logging_steps 10 \
--warmup_ratio 0.1 \
--weight_decay 0.01 \
--overwrite_output_dir True \
--gradient_checkpointing \
--sentence_pooling_method cls \
--same_task_within_batch True \
--shuffle_ratio 0.002 \
--enable_sub_batch True \
--deepspeed ${DS_CONFIG_FILE} \
--ddp_timeout 1800 \
--unified_finetuning True \
--use_self_distill True \
--fix_encoder False \
"
run_cmd="$LAUNCHER --module FlagEmbedding.BGE_M3.run ${full_options}"
echo ${run_cmd}
eval ${run_cmd} 2>&1 | tee $SAVE_PATH/output.log
set +x