add instruction for fine-tuning

parent 50a40621ed
commit 938cf8d5e7
@@ -57,12 +57,6 @@ Train data should be a json file, where each line is a dict like this:

`query` is the query, `pos` is a list of positive texts, and `neg` is a list of negative texts.

If you have no negative texts for a query, you can randomly sample some from the entire corpus as negatives.

Besides, if you want to add an instruction, prepend it to the query text in this file:

```
{"query": your_instruction + str, "pos": List[str], "neg": List[str]}
```

Note that you should use your instruction as the value of the argument `query_instruction_for_retrieval` if you add a query instruction; otherwise, set `query_instruction_for_retrieval=""`.

See [examples/finetune](../../examples/finetune) for a toy data file and training example.
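For illustration, here is a minimal sketch (not part of the repository; the instruction string, corpus, and file name are placeholders) of writing one training line in this format, prepending the instruction to the query and randomly sampling negatives:

```python
import json
import random

# Illustrative placeholders: use your own instruction, corpus, and output path.
instruction = "Represent this sentence for searching relevant passages: "
corpus = ["passage A", "passage B", "passage C", "passage D"]

example = {
    "query": instruction + "a toy query",
    "pos": ["passage A"],
    # If no labeled negatives exist, randomly sample some from the rest of the corpus.
    "neg": random.sample([p for p in corpus if p != "passage A"], k=2),
}

# Each line of the training file is one JSON object.
with open("toy_finetune_data.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(example, ensure_ascii=False) + "\n")
```

If you build your data this way, remember to pass the same instruction as `query_instruction_for_retrieval` during training and inference.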
@@ -75,7 +69,8 @@ python -m FlagEmbedding.baai_general_embedding.finetune.hn_mine \
--model_name_or_path BAAI/bge-base-en \
--input_file toy_finetune_data.jsonl \
--output_file toy_finetune_data_minedHN.jsonl \
--range_for_sampling 2-200
--range_for_sampling 2-200 \
--query_instruction_for_retrieval "Represent this sentence for searching relevant passages: "
```

- `input_file`: json data for finetuning. This script will retrieve the top-k documents for each query,
@@ -91,19 +86,19 @@ The format of this file is the same as pretrain data. If input a candidate_pool,
torchrun --nproc_per_node {number of gpus} \
-m FlagEmbedding.baai_general_embedding.finetune.run \
--output_dir {path to save model} \
--model_name_or_path BAAI/bge-large-en \
--model_name_or_path BAAI/bge-large-zh \
--train_data ./toy_finetune_data.jsonl \
--learning_rate 1e-5 \
--fp16 \
--num_train_epochs 5 \
--per_device_train_batch_size {batch size} \
--dataloader_drop_last True \
--per_device_train_batch_size {large batch size; set 1 for toy data} \
--normlized True \
--temperature 0.02 \
--query_max_len 32 \
--passage_max_len 128 \
--query_max_len 64 \
--passage_max_len 256 \
--train_group_size 2 \
--negatives_cross_device
--negatives_cross_device \
--query_instruction_for_retrieval "为这个句子生成表示以用于检索相关文章:"
```
**Some important arguments**:
@@ -116,8 +111,9 @@ Besides the negatives in this group, the in-batch negatives also will be used in
- `learning_rate`: select an appropriate learning rate for your model. Recommended: 1e-5/2e-5/3e-5 for large/base/small-scale models.
- `temperature`: the similarity scores are scaled as `simi = simi/temperature` before they are used to compute the loss.
A higher temperature can reduce the similarity values between texts in downstream tasks (see the sketch below).
- `query_max_len`: max length for query
- `passage_max_len`: max length for passage
- `query_max_len`: max length for query. Please set it according to the average length of the queries in your data.
- `passage_max_len`: max length for passage. Please set it according to the average length of the passages in your data.
- `query_instruction_for_retrieval`: instruction for query, which will be added to each query.

For more training arguments, please refer to [transformers.TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments).
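To make the temperature scaling concrete, below is a small illustrative sketch (not the repository's training code; names are placeholders) of an in-batch contrastive loss in which similarities are divided by the temperature before the cross-entropy:

```python
import torch
import torch.nn.functional as F

def in_batch_contrastive_loss(q_embeddings, p_embeddings, temperature=0.02):
    """Illustrative loss: passage i is the positive for query i, the rest are in-batch negatives."""
    # Cosine similarities, assuming both embedding matrices are L2-normalized.
    simi = q_embeddings @ p_embeddings.T
    # Temperature scaling: a smaller temperature sharpens the softmax over candidates.
    simi = simi / temperature
    labels = torch.arange(q_embeddings.size(0), device=q_embeddings.device)
    return F.cross_entropy(simi, labels)
```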
@@ -125,27 +121,4 @@ More training arguments please refer to [transformers.TrainingArguments](https:/
### 3. Load your model

After fine-tuning the BGE model, you can load it easily in the same way as [here (with FlagModel)](https://github.com/FlagOpen/FlagEmbedding#using-flagembedding) / [(with transformers)](https://github.com/FlagOpen/FlagEmbedding#using-huggingface-transformers).

Please replace `query_instruction_for_retrieval` with your instruction if you added an instruction for queries in your data json.

If you did not add an instruction for queries in your data, please set `query_instruction_for_retrieval` to `""`.

```python
from FlagEmbedding import FlagModel
model = FlagModel(your_model, query_instruction_for_retrieval="")

queries = ['query_1', 'query_2']
passages = ["样例文档-1", "样例文档-2"]
q_embeddings = model.encode_queries(queries)
p_embeddings = model.encode(passages)
scores = q_embeddings @ p_embeddings.T
```

If you want to load your fine-tuned models with `sentence_transformers`, you should **set the pooling_mode to `cls`** (the default pooling method in sentence_transformers is mean pooling).
You can load your model like this:
```python
from sentence_transformers import SentenceTransformer, models

word_embedding_model = models.Transformer(finetuned_model_path, max_seq_length=512, do_lower_case=True)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode='cls')
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
```
Please replace `query_instruction_for_retrieval` with your instruction if you set a different value for the hyper-parameter `--query_instruction_for_retrieval` when fine-tuning.
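Note that the `sentence_transformers` model loaded above does not prepend the query instruction automatically. A minimal sketch, assuming the `model` built in the block above and an illustrative instruction string, of adding it yourself (mirroring the transformers-style example in the main README):

```python
instruction = "为这个句子生成表示以用于检索相关文章:"  # or "" if you trained without an instruction
queries = ['query_1', 'query_2']
passages = ["样例文档-1", "样例文档-2"]

# Prepend the instruction to queries only; passages are encoded as-is.
q_embeddings = model.encode([instruction + q for q in queries], normalize_embeddings=True)
p_embeddings = model.encode(passages, normalize_embeddings=True)
scores = q_embeddings @ p_embeddings.T
```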
@@ -55,6 +55,13 @@ class DataArguments:
        default=100000000, metadata={"help": "the max number of examples for each dataset"}
    )

    query_instruction_for_retrieval: str= field(
        default=None, metadata={"help": "instruction for query"}
    )
    passage_instruction_for_retrieval: str = field(
        default=None, metadata={"help": "instruction for passage"}
    )

    def __post_init__(self):
        if not os.path.exists(self.train_data):
            raise FileNotFoundError(f"cannot find file: {self.train_data}, please set a true path")
@@ -62,5 +69,5 @@ class DataArguments:
@dataclass
class RetrieverTrainingArguments(TrainingArguments):
    negatives_cross_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
    temperature: Optional[float] = field(default=0.01)
    temperature: Optional[float] = field(default=0.02)
    fix_position_embedding: bool = field(default=False, metadata={"help": "Freeze the parameters of position embeddings"})
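For context, these dataclass fields become command-line flags when the training script parses them; a simplified, illustrative sketch assuming transformers' `HfArgumentParser` is used (only the two new instruction fields are reproduced):

```python
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class DataArguments:
    # Only the two new instruction fields are reproduced here for illustration.
    query_instruction_for_retrieval: str = field(default=None, metadata={"help": "instruction for query"})
    passage_instruction_for_retrieval: str = field(default=None, metadata={"help": "instruction for passage"})

(data_args,) = HfArgumentParser(DataArguments).parse_args_into_dataclasses(
    args=["--query_instruction_for_retrieval", "Represent this sentence for searching relevant passages: "]
)
print(data_args.query_instruction_for_retrieval)
print(data_args.passage_instruction_for_retrieval)  # None: the default when the flag is omitted
```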
@@ -40,6 +40,9 @@ class TrainDatasetForEmbedding(Dataset):

    def __getitem__(self, item) -> Tuple[BatchEncoding, List[BatchEncoding]]:
        query = self.dataset[item]['query']
        if self.args.query_instruction_for_retrieval is not None:
            query = self.args.query_instruction_for_retrieval + query

        passages = []
        pos = random.choice(self.dataset[item]['pos'])
        passages.append(pos)
@@ -51,6 +54,8 @@ class TrainDatasetForEmbedding(Dataset):
            negs = random.sample(self.dataset[item]['neg'], self.args.train_group_size - 1)
        passages.extend(negs)

        if self.args.passage_instruction_for_retrieval is not None:
            passages = [self.args.passage_instruction_for_retrieval+p for p in passages]
        return query, passages
@@ -17,6 +17,7 @@ def get_args():
    parser.add_argument('--range_for_sampling', default=None, type=str, help="range to sample negatives")
    parser.add_argument('--use_gpu_for_searching', action='store_true', help='use faiss-gpu')
    parser.add_argument('--negative_number', default=15, help='use faiss-gpu')
    parser.add_argument('--query_instruction_for_retrieval', default="")

    return parser.parse_args()

@@ -70,7 +71,7 @@ def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, n
    print(f'inferencing embedding for corpus (number={len(corpus)})--------------')
    p_vecs = model.encode(corpus, batch_size=256)
    print(f'inferencing embedding for queries (number={len(queries)})--------------')
    q_vecs = model.encode(queries, batch_size=256)
    q_vecs = model.encode_queries(queries, batch_size=256)

    print('creat index and search------------------')
    index = create_index(p_vecs, use_gpu=use_gpu)
@@ -102,7 +103,7 @@ if __name__ == '__main__':
    sample_range = args.range_for_sampling.split('-')
    sample_range = [int(x) for x in sample_range]

    model = FlagModel(args.model_name_or_path)
    model = FlagModel(args.model_name_or_path, query_instruction_for_retrieval=args.query_instruction_for_retrieval)

    find_knn_neg(model,
                 input_file=args.input_file,
@@ -73,6 +73,7 @@ def main():
                           sentence_pooling_method=model_args.sentence_pooling_method,
                           negatives_cross_device=training_args.negatives_cross_device,
                           temperature=training_args.temperature)
    training_args.sentence_pooling_method = model_args.sentence_pooling_method

    if training_args.fix_position_embedding:
        for k, v in model.named_parameters():
@@ -1,4 +1,12 @@
from transformers.trainer import *
from sentence_transformers import SentenceTransformer, models


def save_ckpt_for_sentence_transformers(ckpt_dir, pooling_mode: str = 'cls'):
    word_embedding_model = models.Transformer(ckpt_dir)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=pooling_mode)
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cpu')
    model.save(ckpt_dir)


class BiTrainer(Trainer):
@@ -17,9 +25,12 @@ class BiTrainer(Trainer):
        if self.tokenizer is not None and self.is_world_process_zero():
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        # save the checkpoint for sentence-transformers library
        if self.is_world_process_zero():
            save_ckpt_for_sentence_transformers(output_dir, pooling_mode=self.args.sentence_pooling_method)

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.
@@ -110,7 +110,7 @@ p_embeddings = model.encode(passages)
scores = q_embeddings @ p_embeddings.T
```
For the value of the argument `query_instruction_for_retrieval`, see [Model List](https://github.com/FlagOpen/FlagEmbedding/tree/master#model-list).
To load your fine-tuned model, use your instruction if you added it during fine-tuning. Set it to an empty string `""` if you did not add an instruction to the query in your json file.
To load your fine-tuned model, use your instruction if you added it during fine-tuning.

By default, FlagModel will use all available GPUs when encoding. Please set `os.environ["CUDA_VISIBLE_DEVICES"]` to select specific GPUs.
You can also set `os.environ["CUDA_VISIBLE_DEVICES"]=""` to make all GPUs unavailable.
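A small illustrative sketch of the GPU-selection note above, assuming the BGE Chinese model and instruction used elsewhere in this README; the environment variable should be set before the model is constructed:

```python
import os

# Select specific GPUs, or set the value to "" to make all GPUs unavailable (CPU only).
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

from FlagEmbedding import FlagModel

model = FlagModel("BAAI/bge-large-zh",
                  query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:")
embeddings = model.encode(["样例文档-1", "样例文档-2"])
```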
@ -147,8 +147,6 @@ q_embeddings = model.encode([instruction+q for q in queries], normalize_embeddin
p_embeddings = model.encode(passages, normalize_embeddings=True)
scores = q_embeddings @ p_embeddings.T
```
If you want to load your fine-tuned models, see [here](https://github.com/FlagOpen/FlagEmbedding/tree/master/FlagEmbedding/baai_general_embedding#3-load-your-model).
And use your instruction if you added it during fine-tuning. Set it to an empty string `""` if you did not add an instruction to the query in your json file.

#### Using Langchain

@@ -162,8 +160,9 @@ model = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
    query_instruction="为这个句子生成表示以用于检索相关文章:"
    query_instruction="为这个句子生成表示以用于检索相关文章:"
)
model.query_instruction = "为这个句子生成表示以用于检索相关文章:"
```
@@ -29,12 +29,6 @@ Train data should be a json file, where each line is a dict like this:
`query` is the query, `pos` is a list of positive texts, and `neg` is a list of negative texts.
If you have no negative texts for a query, you can randomly sample some from the entire corpus as negatives.

Besides, if you want to add an instruction, prepend it to the query text in this file:
```
{"query": your_instruction + str, "pos": List[str], "neg": List[str]}
```
Note that you should use your instruction as the value of the argument `query_instruction_for_retrieval` if you add a query instruction; otherwise, set `query_instruction_for_retrieval=""`.

See [toy_finetune_data.jsonl]() for a toy data file.

**Hard Negatives**
@@ -46,7 +40,8 @@ python -m FlagEmbedding.baai_general_embedding.finetune.hn_mine \
--model_name_or_path BAAI/bge-base-en \
--input_file toy_finetune_data.jsonl \
--output_file toy_finetune_data_minedHN.jsonl \
--range_for_sampling 2-200
--range_for_sampling 2-200 \
--query_instruction_for_retrieval "Represent this sentence for searching relevant passages: "
```

- `input_file`: json data for finetuning. This script will retrieve the top-k documents for each query,
@@ -72,10 +67,11 @@ torchrun --nproc_per_node {number of gpus} \
--dataloader_drop_last True \
--normlized True \
--temperature 0.02 \
--query_max_len 32 \
--passage_max_len 128 \
--query_max_len 64 \
--passage_max_len 256 \
--train_group_size 2 \
--negatives_cross_device
--negatives_cross_device \
--query_instruction_for_retrieval "为这个句子生成表示以用于检索相关文章:"
```

**Some important arguments**:
@@ -88,8 +84,9 @@ Besides the negatives in this group, the in-batch negatives also will be used in
- `learning_rate`: select an appropriate learning rate for your model. Recommended: 1e-5/2e-5/3e-5 for large/base/small-scale models.
- `temperature`: the similarity scores are scaled as `simi = simi/temperature` before they are used to compute the loss.
A higher temperature can reduce the similarity values between texts in downstream tasks.
- `query_max_len`: max length for query
- `passage_max_len`: max length for passage
- `query_max_len`: max length for query. Please set it according to the average length of the queries in your data.
- `passage_max_len`: max length for passage. Please set it according to the average length of the passages in your data.
- `query_instruction_for_retrieval`: instruction for query, which will be added to each query.

For more training arguments, please refer to [transformers.TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments).

@@ -97,28 +94,4 @@ More training arguments please refer to [transformers.TrainingArguments](https:/
### 3. Load your model

After fine-tuning the BGE model, you can load it easily in the same way as [here (with FlagModel)](https://github.com/FlagOpen/FlagEmbedding#using-flagembedding) / [(with transformers)](https://github.com/FlagOpen/FlagEmbedding#using-huggingface-transformers).

Please replace `query_instruction_for_retrieval` with your instruction if you added an instruction for queries in your data json.

If you did not add an instruction for queries in your data, please set `query_instruction_for_retrieval` to `""`.

```python
from FlagEmbedding import FlagModel
model = FlagModel(your_model, query_instruction_for_retrieval="")

queries = ['query_1', 'query_2']
passages = ["样例文档-1", "样例文档-2"]
q_embeddings = model.encode_queries(queries)
p_embeddings = model.encode(passages)
scores = q_embeddings @ p_embeddings.T
```

If you want to load your fine-tuned models with `sentence_transformers`, you should **set the pooling_mode to `cls`** (the default pooling method in sentence_transformers is mean pooling).
You can load your model like this:
```python
from sentence_transformers import SentenceTransformer, models

word_embedding_model = models.Transformer(finetuned_model_path, max_seq_length=512, do_lower_case=True)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode='cls')
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
```

Please replace `query_instruction_for_retrieval` with your instruction if you set a different value for the hyper-parameter `--query_instruction_for_retrieval` when fine-tuning.