mirror of https://github.com/FlagOpen/FlagEmbedding.git (synced 2025-06-27 02:39:58 +00:00)
reranker finetune
This commit is contained in:
parent a39c766965
commit 1d5bf7c518
@@ -62,6 +62,7 @@ class AbsRerankerModel(ABC, nn.Module):
             loss = self.compute_loss(grouped_logits, target)
             if teacher_scores is not None:
                 teacher_targets = teacher_targets.to(grouped_logits.device)
+                # print(teacher_targets, torch.mean(torch.sum(torch.log_softmax(grouped_logits, dim=-1) * teacher_targets, dim=-1)))
                 loss += torch.mean(torch.sum(torch.log_softmax(grouped_logits, dim=-1) * teacher_targets, dim=-1))
         else:
             loss = None
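For context, the surrounding lines implement knowledge distillation from teacher scores: the student reranker's per-group logits are pushed toward the teacher's softmax distribution, on top of the usual cross-entropy ranking loss. A minimal standalone sketch of that idea (the variable names and the sign convention of the distillation term are illustrative, not the repository's exact code):

import torch
import torch.nn.functional as F

# Toy setup: one query with a group of 8 candidate passages; index 0 is the positive.
group_logits = torch.randn(1, 8, requires_grad=True)   # student (reranker) logits
teacher_scores = torch.randn(1, 8)                      # scores exported by a stronger teacher

# Hard ranking loss against the positive passage assumed at index 0 of each group.
target = torch.zeros(group_logits.size(0), dtype=torch.long)
loss = F.cross_entropy(group_logits, target)

# Soft-label distillation: cross-entropy between the teacher distribution and the student log-probs.
teacher_targets = torch.softmax(teacher_scores, dim=-1)
distill = -torch.mean(torch.sum(F.log_softmax(group_logits, dim=-1) * teacher_targets, dim=-1))
(loss + distill).backward()
print(loss.item(), distill.item())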
@@ -176,32 +176,9 @@ def save_merged_model(model_args: RerankerModelArguments, output_dir: str):
                                         cache_dir=model_args.cache_dir,
                                         token=model_args,
                                         trust_remote_code=model_args.trust_remote_code)
-    train_config = LayerWiseMiniCPMConfig.from_pretrained(find_largest_checkpoint(output_dir))
-    config.attention_bias = train_config.attention_bias
-    config.attention_dropout = train_config.attention_dropout
-    config.bos_token_id = train_config.bos_token_id
-    config.dim_model_base = train_config.dim_model_base
-    config.eos_token_id = train_config.eos_token_id
-    config.head_multi = train_config.head_multi
-    config.head_type = train_config.head_type
-    config.hidden_act = train_config.hidden_act
-    config.hidden_size = train_config.hidden_size
-    config.initializer_range = train_config.initializer_range
-    config.max_position_embeddings = train_config.max_position_embeddings
-    config.model_type = train_config.model_type
-    config.num_attention_heads = train_config.num_attention_heads
-    config.num_hidden_layers = train_config.num_hidden_layers
-    config.num_key_value_heads = train_config.num_key_value_heads
-    config.pretraining_tp = train_config.pretraining_tp
-    config.rms_norm_eps = train_config.rms_norm_eps
-    config.rope_scaling = train_config.rope_scaling
-    config.rope_theta = train_config.rope_theta
-    config.scale_depth = train_config.scale_depth
-    config.scale_emb = train_config.scale_emb
-    config.start_layer = train_config.start_layer
-    config.transformers_version = train_config.transformers_version
-    config.use_cache = train_config.use_cache
-    config.vocab_size = train_config.vocab_size
+    config.start_layer = model_args.start_layer
+    config.head_multi = model_args.head_multi
+    config.head_type = model_args.head_type
 
     model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path,
                                                  config=config,
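The hunk above drops the field-by-field copy from the trained checkpoint's LayerWiseMiniCPMConfig and instead sets only the three layerwise-head fields from the command-line arguments before reloading the base model. A rough standalone sketch of that pattern, using hypothetical stand-in objects rather than the repository's classes:

from dataclasses import dataclass
from types import SimpleNamespace

@dataclass
class HeadArgs:
    # Hypothetical stand-in for the layerwise-head options passed on the command line.
    start_layer: int = 8
    head_multi: bool = True
    head_type: str = "simple"

def apply_head_args(config, args: HeadArgs):
    """Copy only the head-related fields onto an existing model config object."""
    for field in ("start_layer", "head_multi", "head_type"):
        setattr(config, field, getattr(args, field))
    return config

config = SimpleNamespace(hidden_size=2304, num_hidden_layers=40)  # toy config object
apply_head_args(config, HeadArgs())
print(config.start_layer, config.head_multi, config.head_type)    # 8 True simple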
@@ -1,10 +1,10 @@
 export WANDB_MODE=disabled
 
 train_data="\
-    ../example_data/normal/examples.jsonl "
+    ../example_data/prompt_based/examples.jsonl "
 
 # set large epochs and small batch size for testing
-num_train_epochs=4
+num_train_epochs=1
 per_device_train_batch_size=2
 gradient_accumulation_steps=1
 train_group_size=8
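Both example scripts now point train_data at the prompt-based example file and keep knowledge distillation enabled. Below is a hedged sketch of what one training record could look like; the field names are an assumption inferred from the --knowledge_distillation and instruction/prompt options in these scripts, so the bundled examples.jsonl remains the authoritative schema:

import json

# Hypothetical training record for reranker finetuning with teacher scores.
record = {
    "query": "what is a corporate bond?",
    "pos": ["A corporate bond is a debt security issued by a company to raise capital."],
    "neg": [
        "A stock represents partial ownership in a company.",
        "Municipal bonds are issued by local governments.",
    ],
    "pos_scores": [0.95],          # teacher scores used when knowledge_distillation is True
    "neg_scores": [0.12, 0.08],
    "prompt": "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'.",
}

with open("examples.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")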
@@ -17,24 +17,35 @@ if [ -z "$HF_HUB_CACHE" ]; then
 fi
 
 model_args="\
-    --model_name_or_path BAAI/bge-reranker-base \
+    --model_name_or_path BAAI/bge-reranker-v2-gemma \
     --cache_dir $HF_HUB_CACHE \
+    --use_lora True \
+    --lora_rank 32 \
+    --lora_alpha 64 \
+    --use_flash_attn True \
+    --target_modules q_proj k_proj v_proj o_proj \
+    --save_merged_lora_model True \
+    --model_type decoder \
 "
 
 data_args="\
     --train_data $train_data \
     --cache_path ~/.cache \
     --train_group_size $train_group_size \
-    --query_max_len 256 \
-    --passage_max_len 256 \
+    --query_max_len 512 \
+    --passage_max_len 512 \
     --pad_to_multiple_of 8 \
     --knowledge_distillation True \
+    --query_instruction_for_retrieval 'A: ' \
+    --query_instruction_format '{}{}' \
+    --passage_instruction_for_retrieval 'B: ' \
+    --passage_instruction_format '{}{}' \
 "
 
 training_args="\
-    --output_dir ./test_encoder_only_base_bge-reranker-base \
+    --output_dir ./test_decoder_only_base_bge-reranker-v2-gemma \
     --overwrite_output_dir \
-    --learning_rate 6e-5 \
+    --learning_rate 2e-4 \
     --fp16 \
     --num_train_epochs $num_train_epochs \
     --per_device_train_batch_size $per_device_train_batch_size \
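The new instruction options prepend 'A: ' to each query and 'B: ' to each passage through the '{}{}' format strings. A toy illustration of that composition; the (instruction, text) argument order is an assumption, not taken from the repository:

# Hypothetical illustration of the instruction formatting used above.
query_instruction_for_retrieval = "A: "
query_instruction_format = "{}{}"
passage_instruction_for_retrieval = "B: "
passage_instruction_format = "{}{}"

query = "what is a corporate bond?"
passage = "A corporate bond is a debt security issued by a company."

print(query_instruction_format.format(query_instruction_for_retrieval, query))
# A: what is a corporate bond?
print(passage_instruction_format.format(passage_instruction_for_retrieval, passage))
# B: A corporate bond is a debt security issued by a company.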
@@ -49,42 +60,12 @@ training_args="\
 "
 
 cmd="torchrun --nproc_per_node $num_gpus \
-    -m FlagEmbedding.finetune.reranker.encoder_only.base \
+    --master_port=4567 \
+    -m FlagEmbedding.finetune.reranker.decoder_only.base \
     $model_args \
     $data_args \
     $training_args \
 "
 
 echo $cmd
 eval $cmd
-
-torchrun --nproc_per_node 8 \
-    -m FlagEmbedding.finetune.reranker.decoder_only.base \
-    --output_dir ./test \
-    --model_name_or_path BAAI/bge-reranker-v2-gemma \
-    --train_data /share/chaofan/dataset/mteb_data_new_score/en/fiqa.jsonl \
-    --cache_dir /share/shared_models \
-    --learning_rate 2e-4 \
-    --num_train_epochs 1 \
-    --max_steps 5 \
-    --per_device_train_batch_size 1 \
-    --gradient_accumulation_steps 16 \
-    --dataloader_drop_last True \
-    --gradient_checkpointing \
-    --query_max_len 512 \
-    --passage_max_len 512 \
-    --train_group_size 16 \
-    --logging_steps 1 \
-    --save_total_limit 50 \
-    --fp16 \
-    --dataloader_drop_last True \
-    --weight_decay 0.01 \
-    --cache_path ./data \
-    --use_lora True \
-    --lora_rank 32 \
-    --lora_alpha 64 \
-    --use_flash_attn True \
-    --target_modules q_proj k_proj v_proj o_proj \
-    --save_merged_lora_model True \
-    --model_type decoder \
-    --deepspeed /share/chaofan/code/stage/stage1.json
-eval $cmd
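Once the gemma run finishes with --save_merged_lora_model True, the merged model can be loaded for scoring. A usage sketch assuming FlagEmbedding's FlagLLMReranker wrapper; the output subdirectory name is hypothetical:

# Usage sketch; the merged-model path below is a hypothetical location inside the output_dir.
from FlagEmbedding import FlagLLMReranker

reranker = FlagLLMReranker(
    "./test_decoder_only_base_bge-reranker-v2-gemma/merged_model",  # hypothetical path
    use_fp16=True,
)
scores = reranker.compute_score([
    ["what is a corporate bond?", "A corporate bond is a debt security issued by a company."],
    ["what is a corporate bond?", "The giant panda is native to south central China."],
])
print(scores)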
@@ -1,10 +1,10 @@
 export WANDB_MODE=disabled
 
 train_data="\
-    ../example_data/normal/examples.jsonl "
+    ../example_data/prompt_based/examples.jsonl "
 
 # set large epochs and small batch size for testing
-num_train_epochs=4
+num_train_epochs=1
 per_device_train_batch_size=2
 gradient_accumulation_steps=1
 train_group_size=8
@@ -17,24 +17,40 @@ if [ -z "$HF_HUB_CACHE" ]; then
 fi
 
 model_args="\
-    --model_name_or_path BAAI/bge-reranker-base \
+    --model_name_or_path /share/chaofan/models/minicpm-2b-fp32-dpo \
     --cache_dir $HF_HUB_CACHE \
+    --use_lora True \
+    --lora_rank 32 \
+    --lora_alpha 64 \
+    --use_flash_attn True \
+    --target_modules q_proj k_proj v_proj o_proj \
+    --save_merged_lora_model True \
+    --model_type decoder \
+    --model_type from_raw_model \
+    --start_layer 8 \
+    --head_multi True \
+    --head_type simple \
+    --trust_remote_code True \
 "
 
 data_args="\
     --train_data $train_data \
     --cache_path ~/.cache \
     --train_group_size $train_group_size \
-    --query_max_len 256 \
-    --passage_max_len 256 \
+    --query_max_len 512 \
+    --passage_max_len 512 \
     --pad_to_multiple_of 8 \
     --knowledge_distillation True \
+    --query_instruction_for_retrieval 'A: ' \
+    --query_instruction_format '{}{}' \
+    --passage_instruction_for_retrieval 'B: ' \
+    --passage_instruction_format '{}{}' \
 "
 
 training_args="\
-    --output_dir ./test_encoder_only_base_bge-reranker-base \
+    --output_dir ./test_decoder_only_base_bge-reranker-v2-layerwise-minicpm \
     --overwrite_output_dir \
-    --learning_rate 6e-5 \
+    --learning_rate 2e-4 \
     --fp16 \
     --num_train_epochs $num_train_epochs \
     --per_device_train_batch_size $per_device_train_batch_size \
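The new --model_type from_raw_model, --start_layer 8, --head_multi True, and --head_type simple options attach per-layer scoring heads to a raw MiniCPM checkpoint. A toy sketch of the layerwise-head idea (not the repository's LayerWiseMiniCPM implementation):

import torch
import torch.nn as nn

class ToyLayerwiseHeads(nn.Module):
    """Toy illustration: every transformer layer at or above start_layer gets its own
    scalar scoring head (head_multi=True; the "simple" head is assumed to be a linear
    projection), so the reranker can emit a relevance score per layer and be cut off early."""

    def __init__(self, hidden_size: int, num_layers: int, start_layer: int):
        super().__init__()
        self.start_layer = start_layer
        self.heads = nn.ModuleList(
            nn.Linear(hidden_size, 1, bias=False) for _ in range(num_layers - start_layer)
        )

    def forward(self, all_hidden_states):
        # all_hidden_states: one (batch, seq_len, hidden) tensor per transformer layer.
        scores = [
            head(hidden[:, -1, :]).squeeze(-1)  # score read off the last token position
            for head, hidden in zip(self.heads, all_hidden_states[self.start_layer:])
        ]
        return torch.stack(scores, dim=0)  # (num_scored_layers, batch)

heads = ToyLayerwiseHeads(hidden_size=64, num_layers=12, start_layer=8)
dummy_states = [torch.randn(2, 5, 64) for _ in range(12)]
print(heads(dummy_states).shape)  # torch.Size([4, 2])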
@@ -49,7 +65,8 @@ training_args="\
 "
 
 cmd="torchrun --nproc_per_node $num_gpus \
-    -m FlagEmbedding.finetune.reranker.encoder_only.base \
+    --master_port=4567 \
+    -m FlagEmbedding.finetune.reranker.decoder_only.layerwise \
     $model_args \
     $data_args \
     $training_args \
@@ -58,38 +75,3 @@ cmd="torchrun --nproc_per_node $num_gpus \
 echo $cmd
 eval $cmd
-
-torchrun --nproc_per_node 8 \
-    -m FlagEmbedding.finetune.reranker.decoder_only.layerwise \
-    --output_dir ./test \
-    --model_name_or_path /share/chaofan/models/minicpm-2b-fp32-dpo \
-    --train_data /share/chaofan/dataset/mteb_data_new_score/en/fiqa.jsonl \
-    --cache_dir /share/shared_models \
-    --learning_rate 2e-4 \
-    --num_train_epochs 1 \
-    --max_steps 5 \
-    --per_device_train_batch_size 1 \
-    --gradient_accumulation_steps 16 \
-    --dataloader_drop_last True \
-    --gradient_checkpointing \
-    --query_max_len 512 \
-    --passage_max_len 512 \
-    --train_group_size 16 \
-    --logging_steps 1 \
-    --save_total_limit 50 \
-    --fp16 \
-    --dataloader_drop_last True \
-    --weight_decay 0.01 \
-    --cache_path ./data \
-    --use_lora True \
-    --lora_rank 32 \
-    --lora_alpha 64 \
-    --use_flash_attn True \
-    --target_modules q_proj k_proj v_proj o_proj linear_head \
-    --save_merged_lora_model True \
-    --model_type decoder \
-    --deepspeed /share/chaofan/code/stage/stage1.json \
-    --model_type from_raw_model \
-    --start_layer 8 \
-    --head_multi True \
-    --head_type simple \
-    --trust_remote_code True
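After training, a layerwise reranker is typically queried with a layer cutoff. A usage sketch assuming FlagEmbedding's LayerWiseFlagLLMReranker wrapper; the local path and the cutoff_layers value are illustrative:

# Usage sketch; the merged-model path is hypothetical and cutoff_layers is illustrative.
from FlagEmbedding import LayerWiseFlagLLMReranker

reranker = LayerWiseFlagLLMReranker(
    "./test_decoder_only_base_bge-reranker-v2-layerwise-minicpm/merged_model",  # hypothetical path
    use_fp16=True,
)
score = reranker.compute_score(
    ["what is a corporate bond?", "A corporate bond is a debt security issued by a company."],
    cutoff_layers=[28],
)
print(score)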
File diff suppressed because one or more lines are too long