update reranker finetune

cfli 2024-10-21 16:13:13 +08:00
parent dbbff43909
commit 4e9d0b386f
7 changed files with 196 additions and 24 deletions

View File

@@ -121,7 +121,7 @@ class AbsRerankerTrainDataset(Dataset):
passages.append(data['neg'][neg_idx])
if self.args.knowledge_distillation:
assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores', list])
assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores'], list)
teacher_scores.append(data['pos_scores'][pos_idx])
for neg_idx in neg_idxs:
teacher_scores.append(data['neg_scores'][neg_idx])
@@ -214,7 +214,7 @@ class AbsLLMRerankerTrainDataset(AbsRerankerTrainDataset):
passages.append(data['neg'][neg_idx])
if self.args.knowledge_distillation:
assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores', list])
assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores'], list)
teacher_scores.append(data['pos_scores'][pos_idx])
for neg_idx in neg_idxs:
teacher_scores.append(data['neg_scores'][neg_idx])
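
The two hunks above fix the same typo in both dataset classes: in the old assertion the closing bracket sits in the wrong place, so data['neg_scores', list] indexes the example dict with the tuple key ('neg_scores', list) and isinstance is left with a single argument. Either way the line raises as soon as knowledge distillation is enabled. A minimal, standalone illustration of the failure and the fix (the data record here is hypothetical, not taken from the repository's example files):

    # Minimal sketch of the assertion fix; the sample record is made up.
    data = {"pos_scores": [0.9], "neg_scores": [0.1, 0.2]}

    # Old form: data['neg_scores', list] looks up the tuple key ('neg_scores', list),
    # which raises KeyError on a plain dict; even if a value came back, isinstance()
    # with one argument would raise TypeError.
    try:
        assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores', list])
    except (KeyError, TypeError) as exc:
        print(f"old assertion crashes: {exc!r}")

    # Fixed form: both isinstance calls receive (value, type) as intended.
    assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores'], list)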

View File

@@ -52,14 +52,16 @@ class AbsRerankerModel(ABC, nn.Module):
def forward(self, pair: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None, teacher_scores: Optional[Tensor] = None):
ranker_logits = self.encode(pair) # (batch_size * num, dim)
if teacher_scores is not None:
teacher_scores = torch.Tensor(teacher_scores)
teacher_targets = teacher_scores.view(self.train_batch_size, -1)
teacher_targets = torch.softmax(teacher_scores.detach(), dim=-1)
teacher_targets = torch.softmax(teacher_targets.detach(), dim=-1)
if self.training:
grouped_logits = ranker_logits.view(self.train_batch_size, -1)
target = torch.zeros(self.train_batch_size, device=grouped_logits.device, dtype=torch.long)
loss = self.compute_loss(grouped_logits, target)
if teacher_scores is not None:
teacher_targets = teacher_targets.to(grouped_logits.device)
loss += torch.mean(torch.sum(torch.log_softmax(grouped_logits, dim=-1) * teacher_targets, dim=-1))
else:
loss = None
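
The change above materializes teacher_scores as a tensor, reshapes it to one row per query (train_batch_size rows, one column per passage in the group), and softmaxes each row into soft targets before folding them into the loss. A self-contained sketch of that target construction with made-up sizes (batch 2, group 4); plain cross_entropy stands in for the model's compute_loss, and the distillation term is written with the conventional leading minus used in the decoder-side model below:

    import torch

    train_batch_size, group_size = 2, 4                                  # hypothetical sizes
    teacher_scores = torch.rand(train_batch_size * group_size).tolist()  # flat list, as produced by the dataset
    ranker_logits = torch.randn(train_batch_size * group_size)           # one student score per (query, passage) pair

    # Teacher targets: reshape to (batch, group) and softmax each row into soft labels.
    teacher_targets = torch.Tensor(teacher_scores).view(train_batch_size, -1)
    teacher_targets = torch.softmax(teacher_targets.detach(), dim=-1)

    # Ranking target: the positive passage sits at index 0 of every group.
    grouped_logits = ranker_logits.view(train_batch_size, -1)
    target = torch.zeros(train_batch_size, dtype=torch.long)
    ce_loss = torch.nn.functional.cross_entropy(grouped_logits, target)

    # Soft-label cross entropy against the teacher distribution.
    kd_loss = -torch.mean(torch.sum(torch.log_softmax(grouped_logits, dim=-1) * teacher_targets, dim=-1))
    loss = ce_loss + kd_loss
    print(loss.item())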

View File

@@ -61,8 +61,9 @@ class CrossDecoderModel(AbsRerankerModel):
)
loss += - torch.mean(torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets, dim=-1))
else:
teacher_scores = torch.Tensor(teacher_scores)
teacher_scores = teacher_scores.view(self.train_batch_size, -1)
teacher_targets = torch.softmax(teacher_scores.detach(), dim=-1)
teacher_targets = torch.softmax(teacher_scores.detach(), dim=-1).to(ranker_logits[-1].device)
for logits in ranker_logits:
student_scores = logits.view(
self.train_batch_size,

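Here ranker_logits is a list with one score tensor per supervised layer, so the same teacher-target term is accumulated once per layer, and the change moves the softmaxed targets onto the device of the last logits tensor. A rough sketch of that per-layer accumulation under the same hypothetical sizes as above, again with cross_entropy standing in for compute_loss:

    import torch

    train_batch_size, group_size, num_layers = 2, 4, 3                   # hypothetical sizes
    ranker_logits = [torch.randn(train_batch_size * group_size) for _ in range(num_layers)]
    teacher_scores = torch.rand(train_batch_size * group_size).tolist()

    teacher_scores = torch.Tensor(teacher_scores).view(train_batch_size, -1)
    teacher_targets = torch.softmax(teacher_scores.detach(), dim=-1).to(ranker_logits[-1].device)

    target = torch.zeros(train_batch_size, dtype=torch.long)
    loss = torch.zeros(())
    for logits in ranker_logits:                                          # one ranking + distillation term per layer
        student_scores = logits.view(train_batch_size, -1)
        loss = loss + torch.nn.functional.cross_entropy(student_scores, target)
        loss = loss - torch.mean(torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets, dim=-1))
    print(loss.item())
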
View File

@@ -1,3 +1,63 @@
export WANDB_MODE=disabled
train_data="\
../example_data/normal/examples.jsonl "
# set large epochs and small batch size for testing
num_train_epochs=4
per_device_train_batch_size=2
gradient_accumulation_steps=1
train_group_size=8
# set num_gpus to 2 for testing
num_gpus=2
if [ -z "$HF_HUB_CACHE" ]; then
export HF_HUB_CACHE="$HOME/.cache/huggingface/hub"
fi
model_args="\
--model_name_or_path BAAI/bge-reranker-base \
--cache_dir $HF_HUB_CACHE \
"
data_args="\
--train_data $train_data \
--cache_path ~/.cache \
--train_group_size $train_group_size \
--query_max_len 256 \
--passage_max_len 256 \
--pad_to_multiple_of 8 \
--knowledge_distillation True \
"
training_args="\
--output_dir ./test_encoder_only_base_bge-reranker-base \
--overwrite_output_dir \
--learning_rate 6e-5 \
--fp16 \
--num_train_epochs $num_train_epochs \
--per_device_train_batch_size $per_device_train_batch_size \
--gradient_accumulation_steps $gradient_accumulation_steps \
--dataloader_drop_last True \
--warmup_ratio 0.1 \
--gradient_checkpointing \
--weight_decay 0.01 \
--deepspeed ../../ds_stage0.json \
--logging_steps 1 \
--save_steps 1000 \
"
cmd="torchrun --nproc_per_node $num_gpus \
-m FlagEmbedding.finetune.reranker.encoder_only.base \
$model_args \
$data_args \
$training_args \
"
echo $cmd
eval $cmd
torchrun --nproc_per_node 8 \
-m FlagEmbedding.finetune.reranker.decoder_only.base \
--output_dir ./test \
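
This script passes --knowledge_distillation True, so every line of examples.jsonl is expected to carry teacher scores alongside the positives and negatives. The field names below follow the dataset code above (pos, neg, pos_scores, neg_scores); the record itself is a made-up illustration, not shipped example data:

    import json

    # Hypothetical training record for reranker distillation.
    record = {
        "query": "what is a reranker?",
        "pos": ["A reranker scores query-passage pairs directly."],
        "neg": ["Rerankers are a kind of bird.", "An unrelated passage."],
        "pos_scores": [0.95],          # teacher score for each positive
        "neg_scores": [0.10, 0.05],    # teacher score for each negative
    }
    with open("examples.jsonl", "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")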

View File

@@ -1,3 +1,63 @@
export WANDB_MODE=disabled
train_data="\
../example_data/normal/examples.jsonl "
# set large epochs and small batch size for testing
num_train_epochs=4
per_device_train_batch_size=2
gradient_accumulation_steps=1
train_group_size=8
# set num_gpus to 2 for testing
num_gpus=2
if [ -z "$HF_HUB_CACHE" ]; then
export HF_HUB_CACHE="$HOME/.cache/huggingface/hub"
fi
model_args="\
--model_name_or_path BAAI/bge-reranker-base \
--cache_dir $HF_HUB_CACHE \
"
data_args="\
--train_data $train_data \
--cache_path ~/.cache \
--train_group_size $train_group_size \
--query_max_len 256 \
--passage_max_len 256 \
--pad_to_multiple_of 8 \
--knowledge_distillation True \
"
training_args="\
--output_dir ./test_encoder_only_base_bge-reranker-base \
--overwrite_output_dir \
--learning_rate 6e-5 \
--fp16 \
--num_train_epochs $num_train_epochs \
--per_device_train_batch_size $per_device_train_batch_size \
--gradient_accumulation_steps $gradient_accumulation_steps \
--dataloader_drop_last True \
--warmup_ratio 0.1 \
--gradient_checkpointing \
--weight_decay 0.01 \
--deepspeed ../../ds_stage0.json \
--logging_steps 1 \
--save_steps 1000 \
"
cmd="torchrun --nproc_per_node $num_gpus \
-m FlagEmbedding.finetune.reranker.encoder_only.base \
$model_args \
$data_args \
$training_args \
"
echo $cmd
eval $cmd
torchrun --nproc_per_node 8 \
-m FlagEmbedding.finetune.reranker.decoder_only.layerwise \
--output_dir ./test \
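
With these test settings each device scores per_device_train_batch_size * train_group_size query-passage pairs per step, which is exactly the flat dimension the model code above reshapes into (train_batch_size, -1). A small sanity-check sketch using the values from this script (assuming the usual convention that the group holds one positive plus train_group_size - 1 negatives):

    per_device_train_batch_size = 2   # queries per device per step (test value from the script)
    train_group_size = 8              # assumed: 1 positive + 7 negatives per query
    num_gpus = 2

    pairs_per_device = per_device_train_batch_size * train_group_size    # 16 logits per device per step
    pairs_per_step = pairs_per_device * num_gpus                          # 32 pairs across both GPUs
    print(pairs_per_device, pairs_per_step)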

View File

@@ -1,20 +1,59 @@
torchrun --nproc_per_node 8 \
-m FlagEmbedding.finetune.reranker.encoder_only.base \
--output_dir ./test \
--model_name_or_path BAAI/bge-reranker-base \
--train_data /share/chaofan/dataset/mteb_data_new_score/en/fiqa.jsonl \
--cache_dir /share/shared_models \
--learning_rate 6e-5 \
--fp16 \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 4 \
--dataloader_drop_last True \
--train_group_size 16 \
--query_max_len 256 \
--passage_max_len 256 \
--weight_decay 0.01 \
--logging_steps 10 \
--gradient_checkpointing \
--cache_path ./data \
--deepspeed /share/chaofan/code/stage/stage1.json
export WANDB_MODE=disabled
train_data="\
../example_data/normal/examples.jsonl "
# set large epochs and small batch size for testing
num_train_epochs=4
per_device_train_batch_size=2
gradient_accumulation_steps=1
train_group_size=8
# set num_gpus to 2 for testing
num_gpus=2
if [ -z "$HF_HUB_CACHE" ]; then
export HF_HUB_CACHE="$HOME/.cache/huggingface/hub"
fi
model_args="\
--model_name_or_path BAAI/bge-reranker-base \
--cache_dir $HF_HUB_CACHE \
"
data_args="\
--train_data $train_data \
--cache_path ~/.cache \
--train_group_size $train_group_size \
--query_max_len 256 \
--passage_max_len 256 \
--pad_to_multiple_of 8 \
--knowledge_distillation True \
"
training_args="\
--output_dir ./test_encoder_only_base_bge-reranker-base \
--overwrite_output_dir \
--learning_rate 6e-5 \
--fp16 \
--num_train_epochs $num_train_epochs \
--per_device_train_batch_size $per_device_train_batch_size \
--gradient_accumulation_steps $gradient_accumulation_steps \
--dataloader_drop_last True \
--warmup_ratio 0.1 \
--gradient_checkpointing \
--weight_decay 0.01 \
--deepspeed ../../ds_stage0.json \
--logging_steps 1 \
--save_steps 1000 \
"
cmd="torchrun --nproc_per_node $num_gpus \
-m FlagEmbedding.finetune.reranker.encoder_only.base \
$model_args \
$data_args \
$training_args \
"
echo $cmd
eval $cmd
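
Once the run finishes, the checkpoint written to --output_dir can be smoke-tested for scoring. A quick sketch assuming FlagEmbedding's FlagReranker wrapper and its documented compute_score usage (verify against the installed version's API):

    from FlagEmbedding import FlagReranker

    # Load the fine-tuned checkpoint from the training script's --output_dir.
    reranker = FlagReranker("./test_encoder_only_base_bge-reranker-base", use_fp16=True)

    score = reranker.compute_score(["what is a reranker?",
                                    "A reranker scores query-passage pairs directly."])
    print(score)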

File diff suppressed because one or more lines are too long