Mirror of https://github.com/FlagOpen/FlagEmbedding.git, synced 2026-01-08 13:11:35 +00:00
Upload code and example of fine-tuning bge-m3
parent 0ad0dee592, commit 8d50e27d30
FlagEmbedding/BGE_M3/__init__.py
@@ -1 +1,2 @@
-from .modeling import BGEM3Model, BGEM3ForInference
+from .modeling import BGEM3Model, BGEM3ForInference, EncoderOutput
+from .trainer import BiTrainer
FlagEmbedding/BGE_M3/arguments.py (new file, 93 lines)
@@ -0,0 +1,93 @@
import os
from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    normlized: bool = field(default=True)


@dataclass
class DataArguments:
    knowledge_distillation: bool = field(
        default=False, metadata={"help": "Use knowledge distillation when `pos_scores` and `neg_scores` are in the features of the training data"}
    )
    train_data: str = field(
        default=None, metadata={"help": "One or more paths to training data", "nargs": "+"}
    )
    cache_path: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the cached data"}
    )
    train_group_size: int = field(default=8)

    query_max_len: int = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization for queries. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )

    passage_max_len: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization for passages. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )

    max_example_num_per_dataset: int = field(
        default=None, metadata={"help": "The max number of examples for each dataset"}
    )

    query_instruction_for_retrieval: str = field(
        default=None, metadata={"help": "Instruction for queries"}
    )
    passage_instruction_for_retrieval: str = field(
        default=None, metadata={"help": "Instruction for passages"}
    )

    same_task_within_batch: bool = field(
        default=False, metadata={"help": "All samples in the same batch come from the same task."}
    )
    shuffle_ratio: float = field(
        default=0.0, metadata={"help": "The ratio of shuffling the text"}
    )

    def __post_init__(self):
        for train_dir in self.train_data:
            if not os.path.exists(train_dir):
                raise FileNotFoundError(f"cannot find file: {train_dir}, please set a correct path")


@dataclass
class RetrieverTrainingArguments(TrainingArguments):
    negatives_cross_device: bool = field(default=False, metadata={"help": "Share negatives across devices"})
    temperature: Optional[float] = field(default=0.02)
    fix_position_embedding: bool = field(default=False, metadata={"help": "Freeze the parameters of position embeddings"})
    sentence_pooling_method: str = field(default='cls', metadata={"help": "The pooling method, should be cls or mean"})
    enable_sub_batch: bool = field(default=True, metadata={"help": "Process each batch as several sub-batches to save GPU memory"})

    unified_finetuning: bool = field(default=False, metadata={"help": "Use unified fine-tuning"})
    use_self_distill: bool = field(default=False, metadata={"help": "Use self-distillation when using unified fine-tuning"})
    fix_encoder: bool = field(default=False, metadata={"help": "Freeze the parameters of the encoder"})
    colbert_dim: int = field(default=-1, metadata={"help": "Dim of the colbert linear layer"})
    self_distill_start_step: int = field(default=-1, metadata={"help": "Step at which self-distillation starts"})
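A minimal sketch of how these dataclasses are parsed from the command line, mirroring what `run.py` below does (the argument values are illustrative, and `--train_data` must point to an existing directory):

```python
from transformers import HfArgumentParser
from FlagEmbedding.BGE_M3.arguments import ModelArguments, DataArguments, RetrieverTrainingArguments

parser = HfArgumentParser((ModelArguments, DataArguments, RetrieverTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses([
    "--model_name_or_path", "BAAI/bge-m3",
    "--train_data", "./toy_train_data",   # nargs="+" in the field metadata makes this a list of paths
    "--output_dir", "./out",
])
print(training_args.temperature)  # 0.02 by default
```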
FlagEmbedding/BGE_M3/data.py (new file, 303 lines)
@@ -0,0 +1,303 @@
import math
import os.path
import random
from dataclasses import dataclass
from typing import List, Tuple

import torch
import numpy as np
import datasets
from pprint import pprint
from torch.utils.data import Dataset
from transformers import DataCollatorWithPadding
from transformers import PreTrainedTokenizer, BatchEncoding
import torch.distributed as dist

from .arguments import DataArguments


class SameDatasetTrainDataset(Dataset):
    """Dataset that yields a whole batch of data at a time. All samples in the same batch come from the same task.
    """
    def __init__(self, args: DataArguments, batch_size: int, seed: int, process_index: int=0, num_processes: int=1):
        train_datasets = []
        each_data_inxs = []
        batch_size_inxs = []
        pqloss_flag = []
        cur_all_num = 0

        SMALL_THRESHOLD = 200
        DROP_THRESHOLD = 200

        context_feat = datasets.Features({
            'query': datasets.Value('string'),
            'pos': datasets.Sequence(datasets.Value('string')),
            'neg': datasets.Sequence(datasets.Value('string'))
        })
        context_feat_kd = datasets.Features({
            'query': datasets.Value('string'),
            'pos': datasets.Sequence(datasets.Value('string')),
            'neg': datasets.Sequence(datasets.Value('string')),
            'pos_scores': datasets.Sequence(datasets.Value('float')),
            'neg_scores': datasets.Sequence(datasets.Value('float')),
        })
        assert isinstance(args.train_data, list) and len(args.train_data) >= 1

        if dist.get_rank() == 0:
            self.print_batch_size(batch_size=batch_size, train_group_size=args.train_group_size)

        for data_dir in args.train_data:
            if not os.path.isdir(data_dir):
                raise FileNotFoundError(f"{data_dir} is not a directory")

            small_datasets = []
            small_batch_size = math.inf

            # Add `parallel_` to `data_dir` to indicate that this dataset is a parallel corpus
            flag = 'parallel_' in data_dir
            for file in os.listdir(data_dir):
                if not (file.endswith('.json') or file.endswith('.jsonl')):
                    continue

                file_path = os.path.join(data_dir, file)
                if dist.get_rank() == 0:
                    print(f'loading data from {file_path} ...')
                try:
                    temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=args.cache_path, features=context_feat)
                except Exception:
                    temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=args.cache_path, features=context_feat_kd)
                    if not args.knowledge_distillation:
                        temp_dataset = temp_dataset.remove_columns(['pos_scores', 'neg_scores'])

                if len(temp_dataset) == 0:
                    continue
                elif len(temp_dataset) < SMALL_THRESHOLD:
                    small_datasets.append(temp_dataset)
                    small_batch_size = min(small_batch_size, self.get_file_batch_size(file, batch_size, train_group_size=args.train_group_size))
                else:
                    if args.max_example_num_per_dataset is not None and len(temp_dataset) > args.max_example_num_per_dataset:
                        temp_dataset = temp_dataset.select(
                            random.sample(list(range(len(temp_dataset))), args.max_example_num_per_dataset))
                    train_datasets.append(temp_dataset)
                    each_data_inxs.append(np.arange(len(temp_dataset)) + cur_all_num)
                    cur_all_num += len(temp_dataset)
                    batch_size_inxs.append(self.get_file_batch_size(file, batch_size, train_group_size=args.train_group_size))
                    pqloss_flag.append(flag)

            if len(small_datasets) > 0:
                small_dataset = datasets.concatenate_datasets(small_datasets)
                if len(small_dataset) >= DROP_THRESHOLD:
                    train_datasets.append(small_dataset)
                    each_data_inxs.append(np.arange(len(small_dataset)) + cur_all_num)
                    cur_all_num += len(small_dataset)
                    batch_size_inxs.append(small_batch_size)
                    pqloss_flag.append(flag)

        self.dataset = datasets.concatenate_datasets(train_datasets)
        self.each_data_inxs = each_data_inxs
        self.datasets_inxs = np.arange(len(each_data_inxs))
        self.batch_size_inxs = batch_size_inxs
        self.pqloss_flag = pqloss_flag

        self.process_index = process_index
        self.num_processes = num_processes
        self.args = args
        self.shuffle_ratio = args.shuffle_ratio

        self.deterministic_generator = np.random.default_rng(seed)
        self.step = 0
        self.refresh_epoch()

    def print_batch_size(self, batch_size: int, train_group_size: int):
        length_list = ['0-500', '500-1000', '1000-2000', '2000-3000', '3000-4000', '4000-5000', '5000-6000', '6000-7000', '7000-inf']
        batch_size_dict = {
            k: self.get_file_batch_size(f"len-{k}.jsonl", batch_size, train_group_size) for k in length_list
        }
        batch_size_list = [
            f'{length}:\t{batch_size_dict[length]}' for length in length_list
        ]
        print("=========================")
        print("Batch Size Dict:")
        pprint(batch_size_list)
        print("=========================")

    @staticmethod
    def get_file_batch_size(file: str, batch_size: int, train_group_size: int):
        if train_group_size == 8:
            # 80GB
            if 'len-0-500.jsonl' in file:
                return 48
            elif 'len-500-1000.jsonl' in file:
                return 32
            elif 'len-1000-2000.jsonl' in file:
                return 20
            elif 'len-2000-3000.jsonl' in file:
                return 18
            elif 'len-3000-4000.jsonl' in file:
                return 14
            elif 'len-4000-5000.jsonl' in file:
                return 14
            elif 'len-5000-6000.jsonl' in file:
                return 12
            elif 'len-6000-7000.jsonl' in file:
                return 10
            elif 'len-7000-inf.jsonl' in file:
                return 8
            else:
                return batch_size
        elif train_group_size == 1:
            # 80GB
            if 'len-0-500.jsonl' in file:
                return 700
            elif 'len-500-1000.jsonl' in file:
                return 570
            elif 'len-1000-2000.jsonl' in file:
                return 388
            elif 'len-2000-3000.jsonl' in file:
                return 288
            elif 'len-3000-4000.jsonl' in file:
                return 224
            elif 'len-4000-5000.jsonl' in file:
                return 180
            elif 'len-5000-6000.jsonl' in file:
                return 157
            elif 'len-6000-7000.jsonl' in file:
                return 128
            elif 'len-7000-inf.jsonl' in file:
                return 104
            else:
                return batch_size
        else:
            return batch_size

    def refresh_epoch(self):
        print(f'---------------------------*Rank {self.process_index}: refresh data---------------------------')
        self.deterministic_generator.shuffle(self.datasets_inxs)
        # Dynamically adjust batch size
        batch_datas = []
        for dataset_inx in self.datasets_inxs:
            self.deterministic_generator.shuffle(self.each_data_inxs[dataset_inx])
            cur_batch_size = self.batch_size_inxs[dataset_inx] * self.num_processes
            flag = self.pqloss_flag[dataset_inx]
            for start_index in range(0, len(self.each_data_inxs[dataset_inx]), cur_batch_size):
                # drop the last batch if it is too small to cover every process
                if len(self.each_data_inxs[dataset_inx]) - start_index < 2 * self.num_processes:
                    break
                batch_datas.append((self.each_data_inxs[dataset_inx][start_index:start_index+cur_batch_size], flag))
        self.deterministic_generator.shuffle(batch_datas)
        self.batch_datas = batch_datas
        self.step = 0

    def __getitem__(self, _):
        batch_indices, pqloss_flag = self.batch_datas[self.step]
        cur_batch_size = int(len(batch_indices) / self.num_processes)
        batch_indices = batch_indices[self.process_index * cur_batch_size: (self.process_index + 1) * cur_batch_size]
        batch_data = self.dataset[batch_indices]
        self.step += 1
        queries, passages, teacher_scores = self.create_batch_data(batch_raw_data=batch_data)
        # print('rank, step, flag, query, passage:', dist.get_rank(), self.step, pqloss_flag, queries, passages)
        return queries, passages, teacher_scores, pqloss_flag

    def shuffle_text(self, text):
        if self.shuffle_ratio > 0 and len(text) > 100 and random.random() < self.shuffle_ratio:
            split_text = []
            chunk_size = len(text) // 3 + 1
            for i in range(0, len(text), chunk_size):
                split_text.append(text[i:i+chunk_size])
            random.shuffle(split_text)
            return " ".join(split_text)
        else:
            return text

    def create_batch_data(self, batch_raw_data):
        queries, passages = [], []
        teacher_scores = []
        for i in range(len(batch_raw_data['query'])):
            queries.append(batch_raw_data['query'][i])

            pos_inx = random.choice(list(range(len(batch_raw_data['pos'][i]))))
            passages.append(self.shuffle_text(batch_raw_data['pos'][i][pos_inx]))
            if 'pos_scores' in batch_raw_data and batch_raw_data['pos_scores'][i] is not None:
                teacher_scores.append(batch_raw_data['pos_scores'][i][pos_inx])

            neg_inx_set = list(range(len(batch_raw_data['neg'][i])))
            if len(batch_raw_data['neg'][i]) < self.args.train_group_size - 1:
                num = math.ceil((self.args.train_group_size - 1) / len(batch_raw_data['neg'][i]))
                neg_inxs = random.sample(neg_inx_set * num, self.args.train_group_size - 1)
            else:
                neg_inxs = random.sample(neg_inx_set, self.args.train_group_size - 1)

            if 'neg_scores' in batch_raw_data and batch_raw_data['neg_scores'][i] is not None:
                neg_scores = [(x, batch_raw_data['neg_scores'][i][x]) for x in neg_inxs]
                neg_scores = sorted(neg_scores, key=lambda x: x[1], reverse=True)
                neg_inxs = [x[0] for x in neg_scores]
                teacher_scores.extend([x[1] for x in neg_scores])

            negs = [batch_raw_data['neg'][i][x] for x in neg_inxs]
            passages.extend(negs)

        if len(teacher_scores) > 0 and len(passages) > 0:
            assert len(teacher_scores) == len(passages)

        if self.args.query_instruction_for_retrieval is not None:
            queries = [self.args.query_instruction_for_retrieval + q for q in queries]
        if self.args.passage_instruction_for_retrieval is not None:
            passages = [self.args.passage_instruction_for_retrieval + p for p in passages]

        if len(teacher_scores) == 0:
            teacher_scores = None
        return queries, passages, teacher_scores

    def __len__(self):
        return len(self.batch_datas) * self.num_processes


@dataclass
class EmbedCollator(DataCollatorWithPadding):
    """
    Wrapper that converts a List[Tuple[encoded_query, encoded_passage]] into a List[query] and a List[passage]
    and passes each batch separately to the actual collator.
    Abstracts data details away from the model.
    """
    query_max_len: int = 32
    passage_max_len: int = 128

    def __call__(self, features):
        query = [f[0] for f in features]
        passage = [f[1] for f in features]

        teacher_scores = None
        if len(features[0]) > 2:
            teacher_scores = [f[2] for f in features]
            if teacher_scores[0] is None:
                teacher_scores = None
            else:
                teacher_scores = torch.FloatTensor(teacher_scores)

        flag = None
        if len(features[0]) == 4:
            flag = [f[3] for f in features][0]

        if isinstance(query[0], list):
            query = sum(query, [])
        if isinstance(passage[0], list):
            passage = sum(passage, [])

        q_collated = self.tokenizer(
            query,
            # padding='max_length',  # used for adjusting the batch size in `get_file_batch_size()`
            padding=True,
            truncation=True,
            max_length=self.query_max_len,
            return_tensors="pt",
        )
        d_collated = self.tokenizer(
            passage,
            # padding='max_length',  # used for adjusting the batch size in `get_file_batch_size()`
            padding=True,
            truncation=True,
            max_length=self.passage_max_len,
            return_tensors="pt",
        )
        if teacher_scores is not None:
            teacher_scores = teacher_scores.reshape((len(q_collated['input_ids']), -1))
        return {"query": q_collated, "passage": d_collated, "teacher_scores": teacher_scores, "bi_directions": flag}
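As a quick sanity check of the collator contract above (a sketch; the toy strings are made up): each item from `SameDatasetTrainDataset` is a `(queries, passages, teacher_scores, flag)` tuple, and `EmbedCollator` flattens the lists and tokenizes them:

```python
from transformers import AutoTokenizer
from FlagEmbedding.BGE_M3.data import EmbedCollator

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3", use_fast=False)
collator = EmbedCollator(tokenizer=tokenizer, query_max_len=32, passage_max_len=128)

# A single feature carries a whole batch, since SameDatasetTrainDataset packs
# one batch per item (which is why run.py sets per_device_train_batch_size=1).
features = [(["what is a dog"],
             ["A dog is a domesticated mammal.", "Unrelated text."],
             None, False)]
batch = collator(features)
print(batch["query"]["input_ids"].shape, batch["passage"]["input_ids"].shape)
```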
FlagEmbedding/BGE_M3/modeling.py
@@ -34,7 +34,7 @@ class BGEM3Model(nn.Module):
                  unified_finetuning: bool = True,
                  use_self_distill: bool = False,
                  colbert_dim: int = -1,
-                 ensemble_distill_start_step: int = -1,
+                 self_distill_start_step: int = -1,
                  ):
         super().__init__()
         self.load_model(model_name, colbert_dim=colbert_dim)
@@ -51,7 +51,7 @@ class BGEM3Model(nn.Module):
         self.enable_sub_batch = enable_sub_batch
         self.temperature = temperature
         self.use_self_distill = use_self_distill
-        self.ensemble_distill_start_step = ensemble_distill_start_step
+        self.self_distill_start_step = self_distill_start_step

         self.step = 0
         if not normlized:
@@ -260,9 +260,9 @@ class BGEM3Model(nn.Module):
             cross_q_dense_vecs = self._dist_gather_tensor(q_dense_vecs)
             cross_p_dense_vecs = self._dist_gather_tensor(p_dense_vecs)

-            idxs = torch.arange(cross_q_dense_vecs.size(0), device=cross_q_dense_vecs.device, dtype=torch.long)
+            cross_idxs = torch.arange(cross_q_dense_vecs.size(0), device=cross_q_dense_vecs.device, dtype=torch.long)

-            cross_targets = idxs * (cross_p_dense_vecs.size(0) // cross_q_dense_vecs.size(0))
+            cross_targets = cross_idxs * (cross_p_dense_vecs.size(0) // cross_q_dense_vecs.size(0))
             cross_dense_scores = self.dense_score(cross_q_dense_vecs, cross_p_dense_vecs)

             loss = self.compute_loss(cross_dense_scores, cross_targets)
@@ -281,7 +281,7 @@ class BGEM3Model(nn.Module):
                 ensemble_loss = self.compute_loss(dense_scores + 0.3 * sparse_scores + colbert_scores, targets)
                 loss = (loss + ensemble_loss + 0.1 * sparse_loss + colbert_loss) / 4

-            if self.use_self_distill and self.step > self.ensemble_distill_start_step and self.unified_finetuning:
+            if self.use_self_distill and self.step > self.self_distill_start_step and self.unified_finetuning:
                 ensemble_scores = dense_scores + 0.3 * sparse_scores + colbert_scores
                 teacher_targets = torch.softmax(ensemble_scores.detach(), dim=-1)
                 ensemble_distill_dense_loss = - torch.mean(
@@ -290,8 +290,7 @@ class BGEM3Model(nn.Module):
                     torch.sum(torch.log_softmax(sparse_scores, dim=-1) * teacher_targets, dim=-1))
                 ensemble_distill_colbert_loss = - torch.mean(
                     torch.sum(torch.log_softmax(colbert_scores, dim=-1) * teacher_targets, dim=-1))
-                loss += (
-                    ensemble_distill_dense_loss + 0.1 * ensemble_distill_sparse_loss + ensemble_distill_colbert_loss) / 3
+                loss += (ensemble_distill_dense_loss + 0.1 * ensemble_distill_sparse_loss + ensemble_distill_colbert_loss) / 3
                 loss = loss / 2
                 self.step += 1
             else:
@@ -367,6 +366,3 @@ class BGEM3ForInference(BGEM3Model):
             output['colbert_vecs'] = torch.nn.functional.normalize(output['colbert_vecs'], dim=-1)

         return output
-
-
-
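Read as equations, the self-distillation branch updated in these hunks computes (writing sg[.] for the `.detach()` stop-gradient):

```latex
s_{\mathrm{ens}} = s_{\mathrm{dense}} + 0.3\, s_{\mathrm{sparse}} + s_{\mathrm{colbert}}, \qquad
p = \mathrm{softmax}\!\left(\mathrm{sg}\!\left[s_{\mathrm{ens}}\right]\right)

\mathcal{L}_{\mathrm{distill}}^{(m)} = -\,\mathrm{mean}\Big(\sum_j p_j \log \mathrm{softmax}(s_m)_j\Big), \quad m \in \{\mathrm{dense}, \mathrm{sparse}, \mathrm{colbert}\}

\mathcal{L} \leftarrow \tfrac{1}{2}\Big(\mathcal{L} + \tfrac{1}{3}\big(\mathcal{L}_{\mathrm{distill}}^{(\mathrm{dense})} + 0.1\,\mathcal{L}_{\mathrm{distill}}^{(\mathrm{sparse})} + \mathcal{L}_{\mathrm{distill}}^{(\mathrm{colbert})}\big)\Big)
```

and this branch is entered once `self.step` exceeds `self_distill_start_step` (the renamed parameter).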
FlagEmbedding/BGE_M3/run.py (new file, 155 lines)
@@ -0,0 +1,155 @@
import logging
import os
from pathlib import Path
import torch.distributed as dist

from transformers import AutoConfig, AutoTokenizer
from transformers import (
    HfArgumentParser,
    set_seed,
)
from transformers import (
    TrainerCallback,
    TrainingArguments,
    TrainerState,
    TrainerControl
)

from .arguments import ModelArguments, DataArguments, \
    RetrieverTrainingArguments as TrainingArguments
from .data import SameDatasetTrainDataset, EmbedCollator
from .modeling import BGEM3Model
from .trainer import BiTrainer


logger = logging.getLogger(__name__)


class TrainerCallbackForDataRefresh(TrainerCallback):
    def __init__(self, train_dataset):
        self.train_dataset = train_dataset

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called at the end of an epoch.
        """
        self.train_dataset.refresh_epoch()


def main():
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)
    logger.info("Model parameters %s", model_args)
    logger.info("Data parameters %s", data_args)

    # Set seed
    set_seed(training_args.seed)

    num_labels = 1
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=False,
    )
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        cache_dir=model_args.cache_dir,
    )
    logger.info('Config: %s', config)

    model = BGEM3Model(model_name=model_args.model_name_or_path,
                       normlized=model_args.normlized,
                       sentence_pooling_method=training_args.sentence_pooling_method,
                       negatives_cross_device=training_args.negatives_cross_device,
                       temperature=training_args.temperature,
                       enable_sub_batch=training_args.enable_sub_batch,
                       unified_finetuning=training_args.unified_finetuning,
                       use_self_distill=training_args.use_self_distill,
                       colbert_dim=training_args.colbert_dim,
                       self_distill_start_step=training_args.self_distill_start_step)

    if training_args.fix_position_embedding:
        for k, v in model.named_parameters():
            if "position_embeddings" in k:
                logging.info(f"Freeze the parameters for {k}")
                v.requires_grad = False
    if training_args.fix_encoder:
        for k, v in model.named_parameters():
            if "colbert_linear" in k or 'sparse_linear' in k:
                logging.info(f"train the parameters for {k}")
            else:
                v.requires_grad = False

    # print(f"===========================Rank {dist.get_rank()}: start loading data===========================")
    if data_args.same_task_within_batch:
        train_dataset = SameDatasetTrainDataset(args=data_args,
                                                batch_size=training_args.per_device_train_batch_size,
                                                seed=training_args.seed,
                                                num_processes=training_args.world_size,
                                                process_index=training_args.process_index)
        training_args.per_device_train_batch_size = 1
        training_args.dataloader_num_workers = 0  # avoid multiple dataloader worker processes
    else:
        raise NotImplementedError("`same_task_within_batch=False` is not supported")

    data_collator = EmbedCollator(
        tokenizer,
        query_max_len=data_args.query_max_len,
        passage_max_len=data_args.passage_max_len
    )

    trainer = BiTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer
    )

    if data_args.same_task_within_batch:
        trainer.add_callback(TrainerCallbackForDataRefresh(train_dataset))

    Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)

    # Training
    # print(f"===========================Rank {dist.get_rank()}: start training===========================")
    trainer.train()
    trainer.save_model()
    # For convenience, we also re-save the tokenizer to the same directory,
    # so that you can share your model easily on huggingface.co/models =)
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    main()
FlagEmbedding/BGE_M3/split_data_by_length.py (new file, 199 lines)
@@ -0,0 +1,199 @@
"""
python split_data_by_length.py \
--input_path train_data \
--output_dir train_data_split \
--cache_dir .cache \
--log_name .split_log \
--length_list 0 500 1000 2000 3000 4000 5000 6000 7000 \
--model_name_or_path BAAI/bge-m3 \
--num_proc 16 \
--overwrite False
"""
import os
import json
import math
import time
import argparse
from tqdm import tqdm
from pprint import pprint
from transformers import AutoTokenizer
from datasets import load_dataset, Features, Value, Sequence


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, required=True, help='the path of the input data')
    parser.add_argument('--output_dir', type=str, required=True, help='the directory of the output data')
    parser.add_argument('--cache_dir', type=str, default=None, help='the cache directory')
    parser.add_argument('--log_name', type=str, default='.split_log', help='the name of the log file, default: `.split_log`, which will be saved to `output_dir`')
    parser.add_argument('--length_list', type=int, default=[0, 500, 1000, 2000, 3000, 4000, 5000, 6000, 7000], nargs='+', help='the length boundaries used for splitting')
    parser.add_argument('--model_name_or_path', type=str, default='BAAI/bge-m3', help='the model name or path of the tokenizer')
    parser.add_argument('--num_proc', type=int, default=16, help='the number of processes, default: 16')
    parser.add_argument('--overwrite', action='store_true', default=False, help='whether to overwrite the output file, default: False')
    args = parser.parse_args()
    return args


class SplitByLengthHandler:
    def __init__(self,
                 model_name_or_path: str,
                 cache_dir: str=None,
                 num_proc: int=16,
                 length_list: list=[0, 500, 1000, 2000, 3000, 4000, 5000, 6000, 7000],
                 overwrite: bool=False):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.cache_dir = cache_dir
        self.num_proc = num_proc
        self.length_ranges_list = self._get_length_ranges_list(length_list)
        self.overwrite = overwrite

        pprint(self.length_ranges_list)

        def _map_func(examples, indices):
            # record each example's global index and its max token length
            # over the query, the positives, and the negatives
            results = {}
            results['idx'] = list(indices)
            results['max_length'] = []
            for i in range(len(examples['query'])):
                query = examples['query'][i]
                pos, neg = examples['pos'][i], examples['neg'][i]
                all_texts = [query] + pos + neg

                max_len = 0
                for x in all_texts:
                    tokenized_x = self.tokenizer(x)['input_ids']
                    if len(tokenized_x) > max_len:
                        max_len = len(tokenized_x)
                results['max_length'].append(max_len)
            return results

        self._map_func = _map_func

    @staticmethod
    def _get_length_ranges_list(length_list: list):
        length_ranges_list = []
        length_list = sorted(length_list)
        for i in range(len(length_list)):
            length_l = length_list[i]
            if i == len(length_list) - 1:
                length_r = math.inf
            else:
                length_r = length_list[i + 1]
            assert 0 <= length_l < length_r
            length_ranges_list.append((length_l, length_r))

        return length_ranges_list

    def _process_dir(self, dir_path: str, output_dir: str):
        assert os.path.isdir(dir_path)
        log_info_list = []
        for file in tqdm(os.listdir(dir_path), desc=f'processing {dir_path}'):
            file_path = os.path.join(dir_path, file)
            if not file_path.endswith('.jsonl'):
                print(f"skip {file_path} ...")
                continue

            output_path = os.path.join(output_dir, '.'.join(file.split('.')[:-1]))
            log_info = self._process_file(file_path, output_path)
            log_info_list.append(log_info)
        return log_info_list

    def _process_file(self, file_path: str, output_path: str):
        assert not os.path.isdir(file_path)

        start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

        features = Features({
            'query': Value('string'),
            'pos': Sequence(Value('string')),
            'neg': Sequence(Value('string'))
        })
        kd_features = Features({
            'query': Value('string'),
            'pos': Sequence(Value('string')),
            'neg': Sequence(Value('string')),
            'pos_scores': Sequence(Value('float')),
            'neg_scores': Sequence(Value('float'))
        })
        try:
            dataset = load_dataset('json', data_files=file_path, cache_dir=self.cache_dir, features=features)['train']
        except Exception:
            dataset = load_dataset('json', data_files=file_path, cache_dir=self.cache_dir, features=kd_features)['train']
        mapped_dataset = dataset.map(self._map_func, batched=True, with_indices=True, num_proc=self.num_proc)

        split_info_dict = {}
        for length_l, length_r in self.length_ranges_list:
            save_path = output_path + f'_len-{length_l}-{length_r}.jsonl'
            if os.path.exists(save_path) and not self.overwrite:
                print(f'{save_path} exists, skip')
                continue

            idxs = mapped_dataset.filter(lambda x: length_l <= x['max_length'] < length_r, num_proc=self.num_proc)
            split_dataset = dataset.select(idxs['idx'])

            split_info_dict[f'len-{length_l}-{length_r}'] = len(split_dataset)

            if len(split_dataset) > 0:
                split_dataset.to_json(save_path, force_ascii=False)

        end_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

        size = len(dataset)
        avg_length = sum(mapped_dataset['max_length']) / size
        log_info = {
            'file_name': os.path.basename(file_path),
            'size': size,
            'avg_length': avg_length,
            'file_path': file_path,
            'start_time': start_time,
            'end_time': end_time,
            'split_info': split_info_dict
        }
        return log_info

    def run(self, input_path: str, output_dir: str, log_name: str=None):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        if log_name is None:
            log_path = os.path.join(output_dir, '.split_log')
        else:
            log_path = os.path.join(output_dir, log_name)

        log_info_list = []

        if os.path.isdir(input_path):
            log_info_list = self._process_dir(input_path, output_dir)
        else:
            file_name = os.path.basename(input_path)
            output_path = os.path.join(output_dir, '.'.join(file_name.split('.')[:-1]))
            log_info = self._process_file(input_path, output_path)
            log_info_list.append(log_info)

        with open(log_path, 'a', encoding='utf-8') as f:
            for log_info in log_info_list:
                json.dump(log_info, f, ensure_ascii=False)
                f.write('\n')


if __name__ == '__main__':
    args = get_args()
    input_path = args.input_path
    output_dir = args.output_dir
    log_name = args.log_name

    handler = SplitByLengthHandler(
        model_name_or_path=args.model_name_or_path,
        cache_dir=args.cache_dir,
        num_proc=args.num_proc,
        length_list=args.length_list if isinstance(args.length_list, list) else [args.length_list],
        overwrite=args.overwrite
    )

    handler.run(
        input_path=input_path,
        output_dir=output_dir,
        log_name=log_name
    )
    print('\nDONE!')
FlagEmbedding/BGE_M3/trainer.py (new file, 51 lines)
@@ -0,0 +1,51 @@
from sentence_transformers import SentenceTransformer, models
from transformers.trainer import *


def save_ckpt_for_sentence_transformers(ckpt_dir, pooling_mode: str = 'cls', normlized: bool=True):
    word_embedding_model = models.Transformer(ckpt_dir)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=pooling_mode)
    if normlized:
        normlize_layer = models.Normalize()
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model, normlize_layer], device='cpu')
    else:
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cpu')
    model.save(ckpt_dir)


class BiTrainer(Trainer):
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not hasattr(self.model, 'save'):
            raise NotImplementedError(
                f'MODEL {self.model.__class__.__name__} '
                f'does not support the save interface')
        else:
            self.model.save(output_dir)
        if self.tokenizer is not None and self.is_world_process_zero():
            self.tokenizer.save_pretrained(output_dir)

        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        # save the checkpoint for the sentence-transformers library
        if self.is_world_process_zero():
            save_ckpt_for_sentence_transformers(output_dir,
                                                pooling_mode=self.args.sentence_pooling_method,
                                                normlized=self.args.normlized)

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        outputs = model(**inputs)
        loss = outputs.loss

        return (loss, outputs) if return_outputs else loss
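Because `_save` also exports a sentence-transformers checkpoint, the saved directory can be loaded directly with that library (a sketch, assuming the model was saved to a hypothetical `./out`):

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("./out")
emb = model.encode(["what is a dog"], normalize_embeddings=True)
print(emb.shape)
```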
examples/unified_finetune/README.md (new file, 107 lines)
@@ -0,0 +1,107 @@
# Unified Finetune

In this example, we show how to perform unified fine-tuning based on [`BAAI/bge-m3`](https://huggingface.co/BAAI/bge-m3) with your own data.

## 1. Installation

- with pip

```
pip install -U FlagEmbedding
```

- from source

```
git clone https://github.com/FlagOpen/FlagEmbedding.git
cd FlagEmbedding
pip install .
```

For development, install as editable:

```
pip install -e .
```

## 2. Data format

Training data should be a jsonl file, where each line is a dict like this:

```
{"query": str, "pos": List[str], "neg": List[str]}
```

`query` is the query, `pos` is a list of positive texts, and `neg` is a list of negative texts.

If you want to use knowledge distillation, each line of your jsonl file should be a dict like this:

```
{"query": str, "pos": List[str], "neg": List[str], "pos_scores": List[float], "neg_scores": List[float]}
```

`pos_scores` is a list of positive scores, where `pos_scores[i]` is the score between `query` and `pos[i]` from the teacher model. `neg_scores` is a list of negative scores, where `neg_scores[i]` is the score between `query` and `neg[i]` from the teacher model.

See [toy_train_data](./toy_train_data) for an example of training data.
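As a quick illustration, here is a minimal Python sketch that writes such a file (the example texts and the `toy.jsonl` filename are made up; a real training file should use one of the two formats throughout):

```python
import json

# A knowledge-distillation example: one teacher score per entry of `pos` and `neg`.
example = {
    "query": "a dog runs on the beach",
    "pos": ["A dog is running along the shore."],
    "neg": ["A cat sleeps on a sofa.", "The meeting was postponed."],
    "pos_scores": [0.95],
    "neg_scores": [0.10, 0.02],
}

with open("toy.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(example, ensure_ascii=False) + "\n")
```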

### Use efficient batching strategy [Optional]

If you want to use the **efficient batching strategy** (for more details, please refer to [our paper](https://arxiv.org/pdf/2402.03216.pdf)), you should use the [split_data_by_length.py](../../FlagEmbedding/BGE_M3/split_data_by_length.py) script to split your data into parts by sequence length before training. For example:

```bash
python split_data_by_length.py \
--input_path train_data \
--output_dir train_data_split \
--cache_dir .cache \
--log_name .split_log \
--length_list 0 500 1000 2000 3000 4000 5000 6000 7000 \
--model_name_or_path BAAI/bge-m3 \
--num_proc 16 \
--overwrite False
```

`input_path` is the path of a jsonl file or a directory containing jsonl files. `output_dir` is the directory where the split files (`*_len-0-500.jsonl`, `*_len-500-1000.jsonl`, etc.) are saved. `output_dir/log_name` is the log of the split. `length_list` is the list of sequence-length boundaries used for splitting. `model_name_or_path` is used to tokenize the data. `num_proc` is the number of processes to use. `overwrite` controls whether existing output files are overwritten.
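Conceptually, the script buckets each example by the longest tokenized text it contains. A condensed sketch of that logic (the `example` dict is hypothetical):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
boundaries = [0, 500, 1000, 2000, 3000, 4000, 5000, 6000, 7000]

def bucket_of(example: dict) -> str:
    # an example's bucket is decided by its longest text among query, pos, and neg
    texts = [example["query"]] + example["pos"] + example["neg"]
    max_len = max(len(tokenizer(t)["input_ids"]) for t in texts)
    for lo, hi in zip(boundaries, boundaries[1:] + [float("inf")]):
        if lo <= max_len < hi:
            return f"len-{lo}-{'inf' if hi == float('inf') else hi}"
```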

For example, if there are two jsonl files `train_data1.jsonl` and `train_data2.jsonl` in the `train_data` directory, then after running the above script, the split files in `train_data_split` will look like this:

```
train_data_split
├── train_data1_len-0-500.jsonl
├── train_data1_len-500-1000.jsonl
├── train_data1_len-1000-2000.jsonl
├── train_data1_len-2000-3000.jsonl
├── train_data1_len-3000-4000.jsonl
├── train_data1_len-4000-5000.jsonl
├── train_data1_len-5000-6000.jsonl
├── train_data1_len-6000-7000.jsonl
├── train_data1_len-7000-inf.jsonl
├── train_data2_len-0-500.jsonl
├── train_data2_len-500-1000.jsonl
├── train_data2_len-1000-2000.jsonl
├── train_data2_len-2000-3000.jsonl
├── train_data2_len-3000-4000.jsonl
├── train_data2_len-4000-5000.jsonl
├── train_data2_len-5000-6000.jsonl
├── train_data2_len-6000-7000.jsonl
└── train_data2_len-7000-inf.jsonl
```

Note that if there's no data in a specific range, the corresponding file will not be created.

## 3. Train

> **Note**: If you only want to fine-tune the dense embedding of `BAAI/bge-m3`, you can refer to [here](../finetune/README.md).

If you want to perform unified fine-tuning based on `BAAI/bge-m3`, please refer to [this script](./unified_finetune_bge-m3_exmaple.sh), which uses `deepspeed` for distributed training. Learn more about `deepspeed` at https://www.deepspeed.ai/getting-started/. There are some important parameters in this script that you need to modify:

- `HOST_FILE_CONTENT`: Machines and GPUs for training. If you want to use multiple machines for training, please refer to https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node (note that you should configure `pdsh` and `ssh` properly).
- `DS_CONFIG_FILE`: Path of the deepspeed config file. [Here](../finetune/ds_config.json) is an example of `ds_config.json`.
- `DATA_PATH`: One or more paths of training data. **Each path must be a directory containing one or more jsonl files**.
- `DEFAULT_BATCH_SIZE`: Default batch size for training. If you use the efficient batching strategy, i.e. you have split your data into parts by sequence length, then the batch size for each part is decided by the `get_file_batch_size()` function in [`BGE_M3/data.py`](../../FlagEmbedding/BGE_M3/data.py) (see the snippet after this list). Before starting training, you should set the batch size for each part in this function according to the GPU memory of your machines. `DEFAULT_BATCH_SIZE` is used for any part whose sequence-length range is not covered by `get_file_batch_size()`.
- `EPOCHS`: Number of training epochs.
- `LEARNING_RATE`: The initial learning rate.
- `SAVE_PATH`: Path for saving the fine-tuned model.

You should set these parameters appropriately.
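For instance, you can check which batch size a given split file will get (the filename here is illustrative):

```python
from FlagEmbedding.BGE_M3.data import SameDatasetTrainDataset

# With the default train_group_size of 8, a 0-500-token split gets batch size 48.
print(SameDatasetTrainDataset.get_file_batch_size(
    "train_data1_len-0-500.jsonl", batch_size=8, train_group_size=8))  # -> 48
```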

For more detailed argument settings, please refer to [`BGE_M3/arguments.py`](../../FlagEmbedding/BGE_M3/arguments.py).
@@ -0,0 +1,10 @@
{"query": "Five women walk along a beach wearing flip-flops.", "pos": ["Some women with flip-flops on, are walking along the beach"], "neg": ["The 4 women are sitting on the beach.", "There was a reform in 1996.", "She's not going to court to clear her record.", "The man is talking about hawaii.", "A woman is standing outside.", "The battle was over. ", "A group of people plays volleyball."]}
{"query": "A woman standing on a high cliff on one leg looking over a river.", "pos": ["A woman is standing on a cliff."], "neg": ["A woman sits on a chair.", "George Bush told the Republicans there was no way he would let them even consider this foolish idea, against his top advisors advice.", "The family was falling apart.", "no one showed up to the meeting", "A boy is sitting outside playing in the sand.", "Ended as soon as I received the wire.", "A child is reading in her bedroom."]}
{"query": "Two woman are playing instruments; one a clarinet, the other a violin.", "pos": ["Some people are playing a tune."], "neg": ["Two women are playing a guitar and drums.", "A man is skiing down a mountain.", "The fatal dose was not taken when the murderer thought it would be.", "Person on bike", "The girl is standing, leaning against the archway.", "A group of women watch soap operas.", "No matter how old people get they never forget. "]}
{"query": "A girl with a blue tank top sitting watching three dogs.", "pos": ["A girl is wearing blue."], "neg": ["A girl is with three cats.", "The people are watching a funeral procession.", "The child is wearing black.", "Financing is an issue for us in public schools.", "Kids at a pool.", "It is calming to be assaulted.", "I face a serious problem at eighteen years old. "]}
{"query": "A yellow dog running along a forest path.", "pos": ["a dog is running"], "neg": ["a cat is running", "Steele did not keep her original story.", "The rule discourages people to pay their child support.", "A man in a vest sits in a car.", "Person in black clothing, with white bandanna and sunglasses waits at a bus stop.", "Neither the Globe or Mail had comments on the current state of Canada's road system. ", "The Spring Creek facility is old and outdated."]}
{"query": "It sets out essential activities in each phase along with critical factors related to those activities.", "pos": ["Critical factors for essential activities are set out."], "neg": ["It lays out critical activities but makes no provision for critical factors related to those activities.", "People are assembled in protest.", "The state would prefer for you to do that.", "A girl sits beside a boy.", "Two males are performing.", "Nobody is jumping", "Conrad was being plotted against, to be hit on the head."]}
{"query": "A man giving a speech in a restaurant.", "pos": ["A person gives a speech."], "neg": ["The man sits at the table and eats food.", "This is definitely not an endorsement.", "They sold their home because they were retiring and not because of the loan.", "The seal of Missouri is perfect.", "Someone is raising their hand.", "An athlete is competing in the 1500 meter swimming competition.", "Two men watching a magic show."]}
{"query": "Indians having a gathering with coats and food and drinks.", "pos": ["A group of Indians are having a gathering with food and drinks"], "neg": ["A group of Indians are having a funeral", "It is only staged on Winter afternoons in Palma's large bullring.", "Right information can empower the legal service practices and the justice system. ", "Meanwhile, the mainland was empty of population.", "Two children is sleeping.", "a fisherman is trying to catch a monkey", "the people are in a train"]}
{"query": "A woman with violet hair rides her bicycle outside.", "pos": ["A woman is riding her bike."], "neg": ["A woman is jogging in the park.", "The street was lined with white-painted houses.", "A group watches a movie inside.", "man at picnics cut steak", "Several chefs are sitting down and talking about food.", "The Commission notes that no significant alternatives were considered.", "We ran out of firewood and had to use pine needles for the fire."]}
{"query": "A man pulls two women down a city street in a rickshaw.", "pos": ["A man is in a city."], "neg": ["A man is a pilot of an airplane.", "It is boring and mundane.", "The morning sunlight was shining brightly and it was warm. ", "Two people jumped off the dock.", "People watching a spaceship launch.", "Mother Teresa is an easy choice.", "It's worth being able to go at a pace you prefer."]}
@@ -0,0 +1,10 @@
{"query": "Five women walk along a beach wearing flip-flops.", "pos": ["Some women with flip-flops on, are walking along the beach"], "neg": ["The 4 women are sitting on the beach.", "There was a reform in 1996.", "She's not going to court to clear her record.", "The man is talking about hawaii.", "A woman is standing outside.", "The battle was over. ", "A group of people plays volleyball."]}
{"query": "A woman standing on a high cliff on one leg looking over a river.", "pos": ["A woman is standing on a cliff."], "neg": ["A woman sits on a chair.", "George Bush told the Republicans there was no way he would let them even consider this foolish idea, against his top advisors advice.", "The family was falling apart.", "no one showed up to the meeting", "A boy is sitting outside playing in the sand.", "Ended as soon as I received the wire.", "A child is reading in her bedroom."]}
{"query": "Two woman are playing instruments; one a clarinet, the other a violin.", "pos": ["Some people are playing a tune."], "neg": ["Two women are playing a guitar and drums.", "A man is skiing down a mountain.", "The fatal dose was not taken when the murderer thought it would be.", "Person on bike", "The girl is standing, leaning against the archway.", "A group of women watch soap operas.", "No matter how old people get they never forget. "]}
{"query": "A girl with a blue tank top sitting watching three dogs.", "pos": ["A girl is wearing blue."], "neg": ["A girl is with three cats.", "The people are watching a funeral procession.", "The child is wearing black.", "Financing is an issue for us in public schools.", "Kids at a pool.", "It is calming to be assaulted.", "I face a serious problem at eighteen years old. "]}
{"query": "A yellow dog running along a forest path.", "pos": ["a dog is running"], "neg": ["a cat is running", "Steele did not keep her original story.", "The rule discourages people to pay their child support.", "A man in a vest sits in a car.", "Person in black clothing, with white bandanna and sunglasses waits at a bus stop.", "Neither the Globe or Mail had comments on the current state of Canada's road system. ", "The Spring Creek facility is old and outdated."]}
{"query": "It sets out essential activities in each phase along with critical factors related to those activities.", "pos": ["Critical factors for essential activities are set out."], "neg": ["It lays out critical activities but makes no provision for critical factors related to those activities.", "People are assembled in protest.", "The state would prefer for you to do that.", "A girl sits beside a boy.", "Two males are performing.", "Nobody is jumping", "Conrad was being plotted against, to be hit on the head."]}
{"query": "A man giving a speech in a restaurant.", "pos": ["A person gives a speech."], "neg": ["The man sits at the table and eats food.", "This is definitely not an endorsement.", "They sold their home because they were retiring and not because of the loan.", "The seal of Missouri is perfect.", "Someone is raising their hand.", "An athlete is competing in the 1500 meter swimming competition.", "Two men watching a magic show."]}
{"query": "Indians having a gathering with coats and food and drinks.", "pos": ["A group of Indians are having a gathering with food and drinks"], "neg": ["A group of Indians are having a funeral", "It is only staged on Winter afternoons in Palma's large bullring.", "Right information can empower the legal service practices and the justice system. ", "Meanwhile, the mainland was empty of population.", "Two children is sleeping.", "a fisherman is trying to catch a monkey", "the people are in a train"]}
{"query": "A woman with violet hair rides her bicycle outside.", "pos": ["A woman is riding her bike."], "neg": ["A woman is jogging in the park.", "The street was lined with white-painted houses.", "A group watches a movie inside.", "man at picnics cut steak", "Several chefs are sitting down and talking about food.", "The Commission notes that no significant alternatives were considered.", "We ran out of firewood and had to use pine needles for the fire."]}
{"query": "A man pulls two women down a city street in a rickshaw.", "pos": ["A man is in a city."], "neg": ["A man is a pilot of an airplane.", "It is boring and mundane.", "The morning sunlight was shining brightly and it was warm. ", "Two people jumped off the dock.", "People watching a spaceship launch.", "Mother Teresa is an easy choice.", "It's worth being able to go at a pace you prefer."]}
examples/unified_finetune/unified_finetune_bge-m3_exmaple.sh (new file, 90 lines)
@@ -0,0 +1,90 @@
#!/bin/bash
# Set root path
ROOT=/home

# Set training machines
# For more details, refer to https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node
HOST_FILE_CONTENT="\
localhost slots=8\n\
"
HOST_FILE=hostfile
printf "$HOST_FILE_CONTENT" > $HOST_FILE

DISTRIBUTED_ARGS="--hostfile $HOST_FILE"

export LAUNCHER="deepspeed \
    $DISTRIBUTED_ARGS \
    "
# Set cache directory
CACHE_PATH=$ROOT/datasets/.cache

# Set path of deepspeed config file
# For more details, refer to https://huggingface.co/docs/transformers/main_classes/deepspeed#zero
DS_CONFIG_FILE=$ROOT/train/ds_config.json

# Set group size of training
GROUP_SIZE=8

# Set paths of training data. Every path **must be a directory path**.
DATA_PATH="
    $ROOT/datasets/toy_train_data \
"

# Set default batch size for training.
# If you want to use the efficient batching strategy, you should first use the script `split_data_by_length.py` to split your data by sequence length. The batch size of every batch will then depend on its sequence-length range (e.g. len-0-500: 48, len-500-1000: 32), as defined in `get_file_batch_size()` in `BGE_M3/data.py`.
DEFAULT_BATCH_SIZE=8

# Set number of training epochs.
# For more details, refer to https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
EPOCHS=1
MAX_STEPS=-1

# Set base model and save path
BASE_MODEL=BAAI/bge-m3

SAVE_PATH=$ROOT/models/bge-m3_finetuned
mkdir -p $SAVE_PATH

# Set learning rate
LEARNING_RATE=1e-6

full_options="
    --knowledge_distillation True \
    --output_dir $SAVE_PATH \
    --model_name_or_path $BASE_MODEL \
    --normlized True \
    --temperature 0.02 \
    --do_train \
    --train_data $DATA_PATH \
    --cache_path $CACHE_PATH \
    --per_device_train_batch_size $DEFAULT_BATCH_SIZE \
    --query_max_len 512 \
    --passage_max_len 8192 \
    --fp16 \
    --save_steps 1500 \
    --train_group_size $GROUP_SIZE \
    --learning_rate $LEARNING_RATE \
    --num_train_epochs $EPOCHS \
    --max_steps $MAX_STEPS \
    --negatives_cross_device False \
    --logging_steps 10 \
    --warmup_ratio 0.1 \
    --weight_decay 0.01 \
    --overwrite_output_dir True \
    --gradient_checkpointing \
    --sentence_pooling_method cls \
    --same_task_within_batch True \
    --shuffle_ratio 0.002 \
    --enable_sub_batch True \
    --deepspeed ${DS_CONFIG_FILE} \
    --ddp_timeout 1800 \
    --unified_finetuning True \
    --use_self_distill True \
    --fix_encoder False \
"

run_cmd="$LAUNCHER --module FlagEmbedding.BGE_M3.run ${full_options}"
echo ${run_cmd}
eval ${run_cmd} 2>&1 | tee $SAVE_PATH/output.log

set +x