From 1bfdf2cbb05838fc91248b2d9e863751065f0bcf Mon Sep 17 00:00:00 2001 From: hanhainebula <2512674094@qq.com> Date: Thu, 13 Feb 2025 22:41:54 +0800 Subject: [PATCH] fix "\\n" issue: replace "\\n" with "\n" --- FlagEmbedding/abc/evaluation/arguments.py | 11 +++++++++++ .../abc/finetune/embedder/AbsArguments.py | 7 +++++++ .../abc/finetune/reranker/AbsArguments.py | 15 +++++++++++---- FlagEmbedding/abc/inference/AbsEmbedder.py | 3 ++- FlagEmbedding/abc/inference/AbsReranker.py | 2 ++ FlagEmbedding/evaluation/air_bench/arguments.py | 11 +++++++++++ .../inference/embedder/decoder_only/icl.py | 2 ++ scripts/hn_mine.py | 7 +++++++ 8 files changed, 53 insertions(+), 5 deletions(-) diff --git a/FlagEmbedding/abc/evaluation/arguments.py b/FlagEmbedding/abc/evaluation/arguments.py index 61dfa6f..4f42526 100644 --- a/FlagEmbedding/abc/evaluation/arguments.py +++ b/FlagEmbedding/abc/evaluation/arguments.py @@ -177,3 +177,14 @@ class AbsEvalModelArgs: compress_layers: Optional[int] = field( default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"} ) + + def __post_init__(self): + # replace "\\n" with "\n" + if "\\n" in self.query_instruction_format_for_retrieval: + self.query_instruction_format_for_retrieval = self.query_instruction_format_for_retrieval.replace("\\n", "\n") + if "\\n" in self.examples_instruction_format: + self.examples_instruction_format = self.examples_instruction_format.replace("\\n", "\n") + if "\\n" in self.query_instruction_format_for_rerank: + self.query_instruction_format_for_rerank = self.query_instruction_format_for_rerank.replace("\\n", "\n") + if "\\n" in self.passage_instruction_format_for_rerank: + self.passage_instruction_format_for_rerank = self.passage_instruction_format_for_rerank.replace("\\n", "\n") diff --git a/FlagEmbedding/abc/finetune/embedder/AbsArguments.py b/FlagEmbedding/abc/finetune/embedder/AbsArguments.py index d636542..fde2b80 100644 --- a/FlagEmbedding/abc/finetune/embedder/AbsArguments.py +++ b/FlagEmbedding/abc/finetune/embedder/AbsArguments.py @@ -114,6 +114,13 @@ class AbsEmbedderDataArguments: ) def __post_init__(self): + # replace "\\n" with "\n" + if "\\n" in self.query_instruction_format: + self.query_instruction_format = self.query_instruction_format.replace("\\n", "\n") + if "\\n" in self.passage_instruction_format: + self.passage_instruction_format = self.passage_instruction_format.replace("\\n", "\n") + + # check the existence of train data for train_dir in self.train_data: if not os.path.exists(train_dir): raise FileNotFoundError(f"cannot find file: {train_dir}, please set a true path") diff --git a/FlagEmbedding/abc/finetune/reranker/AbsArguments.py b/FlagEmbedding/abc/finetune/reranker/AbsArguments.py index 6013be8..3c6a2e9 100644 --- a/FlagEmbedding/abc/finetune/reranker/AbsArguments.py +++ b/FlagEmbedding/abc/finetune/reranker/AbsArguments.py @@ -119,10 +119,17 @@ class AbsRerankerDataArguments: default='\n', metadata={"help": "The sep token for LLM reranker to discriminate between query and passage"} ) - # def __post_init__(self): - # for train_dir in self.train_data: - # if not os.path.exists(train_dir): - # raise FileNotFoundError(f"cannot find file: {train_dir}, please set a true path") + def __post_init__(self): + # replace "\\n" with "\n" + if "\\n" in self.query_instruction_format: + self.query_instruction_format = self.query_instruction_format.replace("\\n", "\n") + if "\\n" in self.passage_instruction_format: + self.passage_instruction_format = self.passage_instruction_format.replace("\\n", "\n") + + # check the existence of train data + for train_dir in self.train_data: + if not os.path.exists(train_dir): + raise FileNotFoundError(f"cannot find file: {train_dir}, please set a true path") @dataclass diff --git a/FlagEmbedding/abc/inference/AbsEmbedder.py b/FlagEmbedding/abc/inference/AbsEmbedder.py index b838723..c603c19 100644 --- a/FlagEmbedding/abc/inference/AbsEmbedder.py +++ b/FlagEmbedding/abc/inference/AbsEmbedder.py @@ -59,7 +59,6 @@ class AbsEmbedder(ABC): convert_to_numpy: bool = True, **kwargs: Any, ): - query_instruction_format = query_instruction_format.replace('\\n', '\n') self.model_name_or_path = model_name_or_path self.normalize_embeddings = normalize_embeddings self.use_fp16 = use_fp16 @@ -152,6 +151,8 @@ class AbsEmbedder(ABC): Returns: str: The complete sentence with instruction """ + if "\\n" in instruction_format: + instruction_format = instruction_format.replace("\\n", "\n") return instruction_format.format(instruction, sentence) def encode_queries( diff --git a/FlagEmbedding/abc/inference/AbsReranker.py b/FlagEmbedding/abc/inference/AbsReranker.py index 840ae37..b6dc07b 100644 --- a/FlagEmbedding/abc/inference/AbsReranker.py +++ b/FlagEmbedding/abc/inference/AbsReranker.py @@ -149,6 +149,8 @@ class AbsReranker(ABC): Returns: str: The complete sentence with instruction """ + if "\\n" in instruction_format: + instruction_format = instruction_format.replace("\\n", "\n") return instruction_format.format(instruction, sentence) def get_detailed_inputs(self, sentence_pairs: Union[str, List[str]]): diff --git a/FlagEmbedding/evaluation/air_bench/arguments.py b/FlagEmbedding/evaluation/air_bench/arguments.py index 1e924bc..d91e4e9 100644 --- a/FlagEmbedding/evaluation/air_bench/arguments.py +++ b/FlagEmbedding/evaluation/air_bench/arguments.py @@ -102,3 +102,14 @@ class AIRBenchEvalModelArgs: compress_layers: Optional[int] = field( default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"} ) + + def __post_init__(self): + # replace "\\n" with "\n" + if "\\n" in self.query_instruction_format_for_retrieval: + self.query_instruction_format_for_retrieval = self.query_instruction_format_for_retrieval.replace("\\n", "\n") + if "\\n" in self.examples_instruction_format: + self.examples_instruction_format = self.examples_instruction_format.replace("\\n", "\n") + if "\\n" in self.query_instruction_format_for_rerank: + self.query_instruction_format_for_rerank = self.query_instruction_format_for_rerank.replace("\\n", "\n") + if "\\n" in self.passage_instruction_format_for_rerank: + self.passage_instruction_format_for_rerank = self.passage_instruction_format_for_rerank.replace("\\n", "\n") diff --git a/FlagEmbedding/inference/embedder/decoder_only/icl.py b/FlagEmbedding/inference/embedder/decoder_only/icl.py index 5ec87e9..53d6902 100644 --- a/FlagEmbedding/inference/embedder/decoder_only/icl.py +++ b/FlagEmbedding/inference/embedder/decoder_only/icl.py @@ -172,6 +172,8 @@ class ICLLLMEmbedder(AbsEmbedder): Returns: str: The complete example following the given format. """ + if "\\n" in instruction_format: + instruction_format = instruction_format.replace("\\n", "\n") return instruction_format.format(instruction, query, response) def stop_self_query_pool(self): diff --git a/scripts/hn_mine.py b/scripts/hn_mine.py index be73625..3092aba 100644 --- a/scripts/hn_mine.py +++ b/scripts/hn_mine.py @@ -90,6 +90,13 @@ class ModelArgs: embedder_passage_max_length: int = field( default=512, metadata={"help": "Max length for passage."} ) + + def __post_init__(self): + # replace "\\n" with "\n" + if "\\n" in self.query_instruction_format_for_retrieval: + self.query_instruction_format_for_retrieval = self.query_instruction_format_for_retrieval.replace("\\n", "\n") + if "\\n" in self.examples_instruction_format: + self.examples_instruction_format = self.examples_instruction_format.replace("\\n", "\n") def create_index(embeddings: np.ndarray, use_gpu: bool = False):