mirror of https://github.com/FlagOpen/FlagEmbedding.git (synced 2025-06-27 02:39:58 +00:00)
update reranker finetune
commit 4e9d0b386f (parent dbbff43909)
@@ -121,7 +121,7 @@ class AbsRerankerTrainDataset(Dataset):
             passages.append(data['neg'][neg_idx])
 
         if self.args.knowledge_distillation:
-            assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores', list])
+            assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores'], list)
             teacher_scores.append(data['pos_scores'][pos_idx])
             for neg_idx in neg_idxs:
                 teacher_scores.append(data['neg_scores'][neg_idx])
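Note on the fix above: in the old assert, data['neg_scores', list] indexes the dict with the tuple ('neg_scores', list), and isinstance gets only one argument, so the check can never succeed. A minimal standalone sketch (the data record below is made up, not from the repo):

data = {'pos_scores': [0.9], 'neg_scores': [0.2, 0.1]}

# Old (broken): the tuple key lookup raises KeyError, and even with a
# valid value, one-argument isinstance(...) would raise TypeError.
#     isinstance(data['neg_scores', list])

# Fixed: value first, then the type to check against.
assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores'], list)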
@@ -214,7 +214,7 @@ class AbsLLMRerankerTrainDataset(AbsRerankerTrainDataset):
             passages.append(data['neg'][neg_idx])
 
         if self.args.knowledge_distillation:
-            assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores', list])
+            assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores'], list)
             teacher_scores.append(data['pos_scores'][pos_idx])
             for neg_idx in neg_idxs:
                 teacher_scores.append(data['neg_scores'][neg_idx])
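The same one-character fix is applied to the LLM reranker dataset. For reference, a hypothetical examples.jsonl record shaped after the fields both datasets read ('pos', 'neg', 'pos_scores', 'neg_scores'); the real schema may carry extra keys:

import json

# Hypothetical record; field names inferred from the dataset code above.
record = {
    "query": "what does a reranker do?",
    "pos": ["A reranker scores each query-passage pair directly."],
    "neg": ["An unrelated passage.", "Another unrelated passage."],
    "pos_scores": [0.95],        # teacher scores aligned with 'pos'
    "neg_scores": [0.10, 0.05],  # teacher scores aligned with 'neg'
}
print(json.dumps(record))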
@@ -52,14 +52,16 @@ class AbsRerankerModel(ABC, nn.Module):
     def forward(self, pair: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None, teacher_scores: Optional[Tensor] = None):
         ranker_logits = self.encode(pair) # (batch_size * num, dim)
         if teacher_scores is not None:
             teacher_scores = torch.Tensor(teacher_scores)
-            teacher_targets = torch.softmax(teacher_scores.detach(), dim=-1)
+            teacher_targets = teacher_scores.view(self.train_batch_size, -1)
+            teacher_targets = torch.softmax(teacher_targets.detach(), dim=-1)
 
         if self.training:
             grouped_logits = ranker_logits.view(self.train_batch_size, -1)
             target = torch.zeros(self.train_batch_size, device=grouped_logits.device, dtype=torch.long)
             loss = self.compute_loss(grouped_logits, target)
             if teacher_scores is not None:
+                teacher_targets = teacher_targets.to(grouped_logits.device)
                 loss += torch.mean(torch.sum(torch.log_softmax(grouped_logits, dim=-1) * teacher_targets, dim=-1))
         else:
             loss = None
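Note on the two additions above: torch.Tensor(teacher_scores) yields one flat CPU vector over the whole batch, so the softmax must be taken per query group, and the targets must be moved to the logits' device before entering the loss. A sketch with made-up numbers (train_batch_size = 2, group size 3):

import torch

teacher_scores = torch.Tensor([3.0, 1.0, 0.5, 2.0, 2.0, 0.1])   # shape (6,)

# Old: a single softmax over all six scores mixes the two queries' groups
# (and its (6,) shape cannot multiply the (2, 3) grouped logits).
flat_targets = torch.softmax(teacher_scores.detach(), dim=-1)

# New: reshape to (train_batch_size, -1) first, then softmax per row.
teacher_targets = torch.softmax(teacher_scores.view(2, -1).detach(), dim=-1)

print(flat_targets.sum())           # tensor(1.)  -- one distribution over all 6
print(teacher_targets.sum(dim=-1))  # tensor([1., 1.]) -- one per query group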
@@ -61,8 +61,9 @@ class CrossDecoderModel(AbsRerankerModel):
                 )
                 loss += - torch.mean(torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets, dim=-1))
         else:
             teacher_scores = torch.Tensor(teacher_scores)
-            teacher_targets = torch.softmax(teacher_scores.detach(), dim=-1)
+            teacher_scores = teacher_scores.view(self.train_batch_size, -1)
+            teacher_targets = torch.softmax(teacher_scores.detach(), dim=-1).to(ranker_logits[-1].device)
             for logits in ranker_logits:
                 student_scores = logits.view(
                     self.train_batch_size,
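Here the same device guard moves the teacher targets to where the per-layer logits live: torch.Tensor(...) allocates on CPU, while ranker_logits sit on the model's device. A sketch under a hypothetical setup (CPU fallback so it runs anywhere):

import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"
train_batch_size = 2
ranker_logits = [torch.randn(6, device=device) for _ in range(2)]  # fake per-layer logits

teacher_scores = torch.Tensor([0.9, 0.1, 0.3, 0.8, 0.2, 0.1])      # CPU tensor
teacher_scores = teacher_scores.view(train_batch_size, -1)
teacher_targets = torch.softmax(teacher_scores.detach(), dim=-1).to(ranker_logits[-1].device)

loss = 0.0
for logits in ranker_logits:
    student_scores = logits.view(train_batch_size, -1)
    # Both operands now share a device, so no cross-device RuntimeError.
    loss += -torch.mean(torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets, dim=-1))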
@@ -1,3 +1,63 @@
-torchrun --nproc_per_node 8 \
-    -m FlagEmbedding.finetune.reranker.decoder_only.base \
-    --output_dir ./test \
+export WANDB_MODE=disabled
+
+train_data="\
+    ../example_data/normal/examples.jsonl "
+
+# set large epochs and small batch size for testing
+num_train_epochs=4
+per_device_train_batch_size=2
+gradient_accumulation_steps=1
+train_group_size=8
+
+# set num_gpus to 2 for testing
+num_gpus=2
+
+if [ -z "$HF_HUB_CACHE" ]; then
+    export HF_HUB_CACHE="$HOME/.cache/huggingface/hub"
+fi
+
+model_args="\
+    --model_name_or_path BAAI/bge-reranker-base \
+    --cache_dir $HF_HUB_CACHE \
+"
+
+data_args="\
+    --train_data $train_data \
+    --cache_path ~/.cache \
+    --train_group_size $train_group_size \
+    --query_max_len 256 \
+    --passage_max_len 256 \
+    --pad_to_multiple_of 8 \
+    --knowledge_distillation True \
+"
+
+training_args="\
+    --output_dir ./test_encoder_only_base_bge-reranker-base \
+    --overwrite_output_dir \
+    --learning_rate 6e-5 \
+    --fp16 \
+    --num_train_epochs $num_train_epochs \
+    --per_device_train_batch_size $per_device_train_batch_size \
+    --gradient_accumulation_steps $gradient_accumulation_steps \
+    --dataloader_drop_last True \
+    --warmup_ratio 0.1 \
+    --gradient_checkpointing \
+    --weight_decay 0.01 \
+    --deepspeed ../../ds_stage0.json \
+    --logging_steps 1 \
+    --save_steps 1000 \
+"
+
+cmd="torchrun --nproc_per_node $num_gpus \
+    -m FlagEmbedding.finetune.reranker.encoder_only.base \
+    $model_args \
+    $data_args \
+    $training_args \
+"
+
+echo $cmd
+eval $cmd
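With the test settings above, the effective batch per optimizer step works out as follows (a quick check, assuming one training example is one query group of train_group_size passages):

num_gpus = 2                      # values copied from the script above
per_device_train_batch_size = 2
gradient_accumulation_steps = 1
train_group_size = 8

groups_per_step = num_gpus * per_device_train_batch_size * gradient_accumulation_steps
pairs_per_step = groups_per_step * train_group_size
print(groups_per_step, pairs_per_step)  # 4 query groups -> 32 query-passage pairs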
@@ -1,3 +1,63 @@
-torchrun --nproc_per_node 8 \
-    -m FlagEmbedding.finetune.reranker.decoder_only.layerwise \
-    --output_dir ./test \
+export WANDB_MODE=disabled
+
+train_data="\
+    ../example_data/normal/examples.jsonl "
+
+# set large epochs and small batch size for testing
+num_train_epochs=4
+per_device_train_batch_size=2
+gradient_accumulation_steps=1
+train_group_size=8
+
+# set num_gpus to 2 for testing
+num_gpus=2
+
+if [ -z "$HF_HUB_CACHE" ]; then
+    export HF_HUB_CACHE="$HOME/.cache/huggingface/hub"
+fi
+
+model_args="\
+    --model_name_or_path BAAI/bge-reranker-base \
+    --cache_dir $HF_HUB_CACHE \
+"
+
+data_args="\
+    --train_data $train_data \
+    --cache_path ~/.cache \
+    --train_group_size $train_group_size \
+    --query_max_len 256 \
+    --passage_max_len 256 \
+    --pad_to_multiple_of 8 \
+    --knowledge_distillation True \
+"
+
+training_args="\
+    --output_dir ./test_encoder_only_base_bge-reranker-base \
+    --overwrite_output_dir \
+    --learning_rate 6e-5 \
+    --fp16 \
+    --num_train_epochs $num_train_epochs \
+    --per_device_train_batch_size $per_device_train_batch_size \
+    --gradient_accumulation_steps $gradient_accumulation_steps \
+    --dataloader_drop_last True \
+    --warmup_ratio 0.1 \
+    --gradient_checkpointing \
+    --weight_decay 0.01 \
+    --deepspeed ../../ds_stage0.json \
+    --logging_steps 1 \
+    --save_steps 1000 \
+"
+
+cmd="torchrun --nproc_per_node $num_gpus \
+    -m FlagEmbedding.finetune.reranker.encoder_only.base \
+    $model_args \
+    $data_args \
+    $training_args \
+"
+
+echo $cmd
+eval $cmd
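The HF_HUB_CACHE fallback in these scripts matches the Hugging Face default cache location. A Python equivalent of that guard, for reference (a sketch, not part of the commit):

import os

# Set HF_HUB_CACHE only if the caller has not already exported it,
# mirroring the scripts' `if [ -z "$HF_HUB_CACHE" ]` check.
os.environ.setdefault(
    "HF_HUB_CACHE",
    os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
)
print(os.environ["HF_HUB_CACHE"])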
@@ -1,20 +1,59 @@
-torchrun --nproc_per_node 8 \
-    -m FlagEmbedding.finetune.reranker.encoder_only.base \
-    --output_dir ./test \
-    --model_name_or_path BAAI/bge-reranker-base \
-    --train_data /share/chaofan/dataset/mteb_data_new_score/en/fiqa.jsonl \
-    --cache_dir /share/shared_models \
-    --learning_rate 6e-5 \
-    --fp16 \
-    --num_train_epochs 1 \
-    --per_device_train_batch_size 16 \
-    --gradient_accumulation_steps 4 \
-    --dataloader_drop_last True \
-    --train_group_size 16 \
-    --query_max_len 256 \
-    --passage_max_len 256 \
-    --weight_decay 0.01 \
-    --logging_steps 10 \
-    --gradient_checkpointing \
-    --cache_path ./data \
-    --deepspeed /share/chaofan/code/stage/stage1.json
+export WANDB_MODE=disabled
+
+train_data="\
+    ../example_data/normal/examples.jsonl "
+
+# set large epochs and small batch size for testing
+num_train_epochs=4
+per_device_train_batch_size=2
+gradient_accumulation_steps=1
+train_group_size=8
+
+# set num_gpus to 2 for testing
+num_gpus=2
+
+if [ -z "$HF_HUB_CACHE" ]; then
+    export HF_HUB_CACHE="$HOME/.cache/huggingface/hub"
+fi
+
+model_args="\
+    --model_name_or_path BAAI/bge-reranker-base \
+    --cache_dir $HF_HUB_CACHE \
+"
+
+data_args="\
+    --train_data $train_data \
+    --cache_path ~/.cache \
+    --train_group_size $train_group_size \
+    --query_max_len 256 \
+    --passage_max_len 256 \
+    --pad_to_multiple_of 8 \
+    --knowledge_distillation True \
+"
+
+training_args="\
+    --output_dir ./test_encoder_only_base_bge-reranker-base \
+    --overwrite_output_dir \
+    --learning_rate 6e-5 \
+    --fp16 \
+    --num_train_epochs $num_train_epochs \
+    --per_device_train_batch_size $per_device_train_batch_size \
+    --gradient_accumulation_steps $gradient_accumulation_steps \
+    --dataloader_drop_last True \
+    --warmup_ratio 0.1 \
+    --gradient_checkpointing \
+    --weight_decay 0.01 \
+    --deepspeed ../../ds_stage0.json \
+    --logging_steps 1 \
+    --save_steps 1000 \
+"
+
+cmd="torchrun --nproc_per_node $num_gpus \
+    -m FlagEmbedding.finetune.reranker.encoder_only.base \
+    $model_args \
+    $data_args \
+    $training_args \
+"
+
+echo $cmd
+eval $cmd
File diff suppressed because one or more lines are too long