Fix "\\n" issue: replace the literal two-character sequence "\\n" with a real newline "\n"

This commit is contained in:
hanhainebula 2025-02-13 22:41:54 +08:00
parent 317aee47bf
commit 1bfdf2cbb0
8 changed files with 53 additions and 5 deletions

View File

@ -177,3 +177,14 @@ class AbsEvalModelArgs:
compress_layers: Optional[int] = field(
default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"}
)
def __post_init__(self):
    """Normalize escaped newlines in the instruction format fields.

    Any literal backslash-n (two characters) found in one of the four
    instruction format attributes is rewritten to a real newline character.
    NOTE(review): presumably the escaped form comes from values passed on the
    command line -- confirm against the argument parser.
    """
    format_attrs = (
        "query_instruction_format_for_retrieval",
        "examples_instruction_format",
        "query_instruction_format_for_rerank",
        "passage_instruction_format_for_rerank",
    )
    for attr_name in format_attrs:
        current = getattr(self, attr_name)
        # Only reassign when the escaped sequence is actually present.
        if "\\n" in current:
            setattr(self, attr_name, current.replace("\\n", "\n"))

View File

@ -114,6 +114,13 @@ class AbsEmbedderDataArguments:
)
def __post_init__(self):
    """Normalize instruction formats and validate the training data paths.

    First, a literal backslash-n (two characters) in the query or passage
    instruction format is rewritten to a real newline character. Then every
    entry of ``self.train_data`` is checked for existence on disk.

    Raises:
        FileNotFoundError: If an entry of ``self.train_data`` does not exist.
    """
    for attr_name in ("query_instruction_format", "passage_instruction_format"):
        current = getattr(self, attr_name)
        # Only reassign when the escaped sequence is actually present.
        if "\\n" in current:
            setattr(self, attr_name, current.replace("\\n", "\n"))

    # Fail on the first training data path that is missing.
    for data_path in self.train_data:
        if os.path.exists(data_path):
            continue
        raise FileNotFoundError(f"cannot find file: {data_path}, please set a true path")

View File

@ -119,10 +119,17 @@ class AbsRerankerDataArguments:
default='\n', metadata={"help": "The sep token for LLM reranker to discriminate between query and passage"}
)
# def __post_init__(self):
# for train_dir in self.train_data:
# if not os.path.exists(train_dir):
# raise FileNotFoundError(f"cannot find file: {train_dir}, please set a true path")
def __post_init__(self):
    """Normalize instruction formats and validate the training data paths.

    A literal backslash-n (two characters) in the query or passage
    instruction format is rewritten to a real newline character, after which
    each entry of ``self.train_data`` must exist on disk.

    Raises:
        FileNotFoundError: If an entry of ``self.train_data`` does not exist.
    """
    for field_name in ("query_instruction_format", "passage_instruction_format"):
        template = getattr(self, field_name)
        # Leave the attribute untouched unless the escaped sequence occurs.
        if "\\n" in template:
            setattr(self, field_name, template.replace("\\n", "\n"))

    # Stop at the first training data path that cannot be found.
    for candidate in self.train_data:
        if os.path.exists(candidate):
            continue
        raise FileNotFoundError(f"cannot find file: {candidate}, please set a true path")
@dataclass

View File

@ -59,7 +59,6 @@ class AbsEmbedder(ABC):
convert_to_numpy: bool = True,
**kwargs: Any,
):
query_instruction_format = query_instruction_format.replace('\\n', '\n')
self.model_name_or_path = model_name_or_path
self.normalize_embeddings = normalize_embeddings
self.use_fp16 = use_fp16
@ -152,6 +151,8 @@ class AbsEmbedder(ABC):
Returns:
str: The complete sentence with instruction
"""
if "\\n" in instruction_format:
instruction_format = instruction_format.replace("\\n", "\n")
return instruction_format.format(instruction, sentence)
def encode_queries(

View File

@ -149,6 +149,8 @@ class AbsReranker(ABC):
Returns:
str: The complete sentence with instruction
"""
if "\\n" in instruction_format:
instruction_format = instruction_format.replace("\\n", "\n")
return instruction_format.format(instruction, sentence)
def get_detailed_inputs(self, sentence_pairs: Union[str, List[str]]):

View File

@ -102,3 +102,14 @@ class AIRBenchEvalModelArgs:
compress_layers: Optional[int] = field(
default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"}
)
def __post_init__(self):
    """Normalize escaped newlines in the instruction format fields.

    For each of the four instruction format attributes, a literal
    backslash-n (two characters) is rewritten to a real newline character.
    """
    for field_name in (
        "query_instruction_format_for_retrieval",
        "examples_instruction_format",
        "query_instruction_format_for_rerank",
        "passage_instruction_format_for_rerank",
    ):
        template = getattr(self, field_name)
        # Leave the attribute untouched unless the escaped sequence occurs.
        if "\\n" in template:
            setattr(self, field_name, template.replace("\\n", "\n"))

View File

@ -172,6 +172,8 @@ class ICLLLMEmbedder(AbsEmbedder):
Returns:
str: The complete example following the given format.
"""
if "\\n" in instruction_format:
instruction_format = instruction_format.replace("\\n", "\n")
return instruction_format.format(instruction, query, response)
def stop_self_query_pool(self):

View File

@ -90,6 +90,13 @@ class ModelArgs:
embedder_passage_max_length: int = field(
default=512, metadata={"help": "Max length for passage."}
)
def __post_init__(self):
    """Normalize escaped newlines in the retrieval and examples formats.

    A literal backslash-n (two characters) in either attribute is rewritten
    to a real newline character.
    """
    for attr_name in (
        "query_instruction_format_for_retrieval",
        "examples_instruction_format",
    ):
        current = getattr(self, attr_name)
        # Only reassign when the escaped sequence is actually present.
        if "\\n" in current:
            setattr(self, attr_name, current.replace("\\n", "\n"))
def create_index(embeddings: np.ndarray, use_gpu: bool = False):