mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2025-06-27 02:39:58 +00:00
fix "\\n" issue: replace "\\n" with "\n"
This commit is contained in:
parent
317aee47bf
commit
1bfdf2cbb0
@ -177,3 +177,14 @@ class AbsEvalModelArgs:
|
||||
compress_layers: Optional[int] = field(
|
||||
default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"}
|
||||
)
|
||||
|
||||
def __post_init__(self):
    r"""Unescape newlines in the instruction-format attributes.

    Each attribute named below is checked for the two-character sequence
    ``\n`` (a backslash followed by the letter ``n``). Every occurrence is
    replaced with a real line-feed character and the result is stored back
    on the instance. Attributes that do not hold a string, or that hold a
    string without that sequence, are left untouched.
    """
    format_fields = (
        "query_instruction_format_for_retrieval",
        "examples_instruction_format",
        "query_instruction_format_for_rerank",
        "passage_instruction_format_for_rerank",
    )
    for name in format_fields:
        value = getattr(self, name)
        # A non-string value (for instance ``None``) cannot carry an escaped
        # newline, and a bare ``"\\n" in value`` test would raise TypeError.
        # Only write back when something actually changes.
        if isinstance(value, str) and "\\n" in value:
            setattr(self, name, value.replace("\\n", "\n"))
|
||||
|
@ -114,6 +114,13 @@ class AbsEmbedderDataArguments:
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# replace "\\n" with "\n"
|
||||
if "\\n" in self.query_instruction_format:
|
||||
self.query_instruction_format = self.query_instruction_format.replace("\\n", "\n")
|
||||
if "\\n" in self.passage_instruction_format:
|
||||
self.passage_instruction_format = self.passage_instruction_format.replace("\\n", "\n")
|
||||
|
||||
# check the existence of train data
|
||||
for train_dir in self.train_data:
|
||||
if not os.path.exists(train_dir):
|
||||
raise FileNotFoundError(f"cannot find file: {train_dir}, please set a true path")
|
||||
|
@ -119,10 +119,17 @@ class AbsRerankerDataArguments:
|
||||
default='\n', metadata={"help": "The sep token for LLM reranker to discriminate between query and passage"}
|
||||
)
|
||||
|
||||
# def __post_init__(self):
|
||||
# for train_dir in self.train_data:
|
||||
# if not os.path.exists(train_dir):
|
||||
# raise FileNotFoundError(f"cannot find file: {train_dir}, please set a true path")
|
||||
def __post_init__(self):
    r"""Unescape newlines in the instruction formats, then validate the training paths.

    ``query_instruction_format`` and ``passage_instruction_format`` are
    checked for the two-character sequence ``\n`` (a backslash followed by
    the letter ``n``). Every occurrence is replaced with a real line-feed
    character. Values that are not strings are left untouched.

    Raises:
        FileNotFoundError: If an entry of ``self.train_data`` does not exist
            on disk.
    """
    for name in ("query_instruction_format", "passage_instruction_format"):
        value = getattr(self, name)
        # A non-string value (for instance ``None``) cannot carry an escaped
        # newline, and a bare ``"\\n" in value`` test would raise TypeError.
        if isinstance(value, str) and "\\n" in value:
            setattr(self, name, value.replace("\\n", "\n"))

    # check the existence of train data
    for train_dir in self.train_data:
        if not os.path.exists(train_dir):
            raise FileNotFoundError(f"cannot find file: {train_dir}, please set a true path")
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -59,7 +59,6 @@ class AbsEmbedder(ABC):
|
||||
convert_to_numpy: bool = True,
|
||||
**kwargs: Any,
|
||||
):
|
||||
query_instruction_format = query_instruction_format.replace('\\n', '\n')
|
||||
self.model_name_or_path = model_name_or_path
|
||||
self.normalize_embeddings = normalize_embeddings
|
||||
self.use_fp16 = use_fp16
|
||||
@ -152,6 +151,8 @@ class AbsEmbedder(ABC):
|
||||
Returns:
|
||||
str: The complete sentence with instruction
|
||||
"""
|
||||
if "\\n" in instruction_format:
|
||||
instruction_format = instruction_format.replace("\\n", "\n")
|
||||
return instruction_format.format(instruction, sentence)
|
||||
|
||||
def encode_queries(
|
||||
|
@ -149,6 +149,8 @@ class AbsReranker(ABC):
|
||||
Returns:
|
||||
str: The complete sentence with instruction
|
||||
"""
|
||||
if "\\n" in instruction_format:
|
||||
instruction_format = instruction_format.replace("\\n", "\n")
|
||||
return instruction_format.format(instruction, sentence)
|
||||
|
||||
def get_detailed_inputs(self, sentence_pairs: Union[str, List[str]]):
|
||||
|
@ -102,3 +102,14 @@ class AIRBenchEvalModelArgs:
|
||||
compress_layers: Optional[int] = field(
|
||||
default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"}
|
||||
)
|
||||
|
||||
def __post_init__(self):
    r"""Unescape newlines in the instruction-format attributes.

    For each attribute named below, the two-character sequence ``\n`` (a
    backslash followed by the letter ``n``) is replaced with a real
    line-feed character and the result is stored back on the instance.
    Attributes that do not hold a string, or that hold a string without
    that sequence, are left untouched.
    """
    format_fields = (
        "query_instruction_format_for_retrieval",
        "examples_instruction_format",
        "query_instruction_format_for_rerank",
        "passage_instruction_format_for_rerank",
    )
    for name in format_fields:
        value = getattr(self, name)
        # Skip non-string values (for instance ``None``): the membership
        # test below would raise TypeError on them.
        if not isinstance(value, str):
            continue
        if "\\n" in value:
            setattr(self, name, value.replace("\\n", "\n"))
|
||||
|
@ -172,6 +172,8 @@ class ICLLLMEmbedder(AbsEmbedder):
|
||||
Returns:
|
||||
str: The complete example following the given format.
|
||||
"""
|
||||
if "\\n" in instruction_format:
|
||||
instruction_format = instruction_format.replace("\\n", "\n")
|
||||
return instruction_format.format(instruction, query, response)
|
||||
|
||||
def stop_self_query_pool(self):
|
||||
|
@ -90,6 +90,13 @@ class ModelArgs:
|
||||
embedder_passage_max_length: int = field(
|
||||
default=512, metadata={"help": "Max length for passage."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
    r"""Unescape newlines in the retrieval and examples instruction formats.

    In ``query_instruction_format_for_retrieval`` and
    ``examples_instruction_format`` the two-character sequence ``\n`` (a
    backslash followed by the letter ``n``) is replaced with a real
    line-feed character. Values that are not strings are left untouched.
    """
    for name in ("query_instruction_format_for_retrieval", "examples_instruction_format"):
        value = getattr(self, name)
        # A non-string value (for instance ``None``) cannot carry an escaped
        # newline, and a bare ``"\\n" in value`` test would raise TypeError.
        if isinstance(value, str) and "\\n" in value:
            setattr(self, name, value.replace("\\n", "\n"))
|
||||
|
||||
|
||||
def create_index(embeddings: np.ndarray, use_gpu: bool = False):
|
||||
|
Loading…
x
Reference in New Issue
Block a user