mirror of
https://github.com/FlagOpen/FlagEmbedding.git
synced 2025-06-27 02:39:58 +00:00
fix "\\n" issue: replace "\\n" with "\n"
This commit is contained in:
parent
317aee47bf
commit
1bfdf2cbb0
@ -177,3 +177,14 @@ class AbsEvalModelArgs:
|
|||||||
compress_layers: Optional[int] = field(
|
compress_layers: Optional[int] = field(
|
||||||
default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"}
|
default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# replace "\\n" with "\n"
|
||||||
|
if "\\n" in self.query_instruction_format_for_retrieval:
|
||||||
|
self.query_instruction_format_for_retrieval = self.query_instruction_format_for_retrieval.replace("\\n", "\n")
|
||||||
|
if "\\n" in self.examples_instruction_format:
|
||||||
|
self.examples_instruction_format = self.examples_instruction_format.replace("\\n", "\n")
|
||||||
|
if "\\n" in self.query_instruction_format_for_rerank:
|
||||||
|
self.query_instruction_format_for_rerank = self.query_instruction_format_for_rerank.replace("\\n", "\n")
|
||||||
|
if "\\n" in self.passage_instruction_format_for_rerank:
|
||||||
|
self.passage_instruction_format_for_rerank = self.passage_instruction_format_for_rerank.replace("\\n", "\n")
|
||||||
|
@ -114,6 +114,13 @@ class AbsEmbedderDataArguments:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
# replace "\\n" with "\n"
|
||||||
|
if "\\n" in self.query_instruction_format:
|
||||||
|
self.query_instruction_format = self.query_instruction_format.replace("\\n", "\n")
|
||||||
|
if "\\n" in self.passage_instruction_format:
|
||||||
|
self.passage_instruction_format = self.passage_instruction_format.replace("\\n", "\n")
|
||||||
|
|
||||||
|
# check the existence of train data
|
||||||
for train_dir in self.train_data:
|
for train_dir in self.train_data:
|
||||||
if not os.path.exists(train_dir):
|
if not os.path.exists(train_dir):
|
||||||
raise FileNotFoundError(f"cannot find file: {train_dir}, please set a true path")
|
raise FileNotFoundError(f"cannot find file: {train_dir}, please set a true path")
|
||||||
|
@ -119,10 +119,17 @@ class AbsRerankerDataArguments:
|
|||||||
default='\n', metadata={"help": "The sep token for LLM reranker to discriminate between query and passage"}
|
default='\n', metadata={"help": "The sep token for LLM reranker to discriminate between query and passage"}
|
||||||
)
|
)
|
||||||
|
|
||||||
# def __post_init__(self):
|
def __post_init__(self):
|
||||||
# for train_dir in self.train_data:
|
# replace "\\n" with "\n"
|
||||||
# if not os.path.exists(train_dir):
|
if "\\n" in self.query_instruction_format:
|
||||||
# raise FileNotFoundError(f"cannot find file: {train_dir}, please set a true path")
|
self.query_instruction_format = self.query_instruction_format.replace("\\n", "\n")
|
||||||
|
if "\\n" in self.passage_instruction_format:
|
||||||
|
self.passage_instruction_format = self.passage_instruction_format.replace("\\n", "\n")
|
||||||
|
|
||||||
|
# check the existence of train data
|
||||||
|
for train_dir in self.train_data:
|
||||||
|
if not os.path.exists(train_dir):
|
||||||
|
raise FileNotFoundError(f"cannot find file: {train_dir}, please set a true path")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -59,7 +59,6 @@ class AbsEmbedder(ABC):
|
|||||||
convert_to_numpy: bool = True,
|
convert_to_numpy: bool = True,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
query_instruction_format = query_instruction_format.replace('\\n', '\n')
|
|
||||||
self.model_name_or_path = model_name_or_path
|
self.model_name_or_path = model_name_or_path
|
||||||
self.normalize_embeddings = normalize_embeddings
|
self.normalize_embeddings = normalize_embeddings
|
||||||
self.use_fp16 = use_fp16
|
self.use_fp16 = use_fp16
|
||||||
@ -152,6 +151,8 @@ class AbsEmbedder(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
str: The complete sentence with instruction
|
str: The complete sentence with instruction
|
||||||
"""
|
"""
|
||||||
|
if "\\n" in instruction_format:
|
||||||
|
instruction_format = instruction_format.replace("\\n", "\n")
|
||||||
return instruction_format.format(instruction, sentence)
|
return instruction_format.format(instruction, sentence)
|
||||||
|
|
||||||
def encode_queries(
|
def encode_queries(
|
||||||
|
@ -149,6 +149,8 @@ class AbsReranker(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
str: The complete sentence with instruction
|
str: The complete sentence with instruction
|
||||||
"""
|
"""
|
||||||
|
if "\\n" in instruction_format:
|
||||||
|
instruction_format = instruction_format.replace("\\n", "\n")
|
||||||
return instruction_format.format(instruction, sentence)
|
return instruction_format.format(instruction, sentence)
|
||||||
|
|
||||||
def get_detailed_inputs(self, sentence_pairs: Union[str, List[str]]):
|
def get_detailed_inputs(self, sentence_pairs: Union[str, List[str]]):
|
||||||
|
@ -102,3 +102,14 @@ class AIRBenchEvalModelArgs:
|
|||||||
compress_layers: Optional[int] = field(
|
compress_layers: Optional[int] = field(
|
||||||
default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"}
|
default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# replace "\\n" with "\n"
|
||||||
|
if "\\n" in self.query_instruction_format_for_retrieval:
|
||||||
|
self.query_instruction_format_for_retrieval = self.query_instruction_format_for_retrieval.replace("\\n", "\n")
|
||||||
|
if "\\n" in self.examples_instruction_format:
|
||||||
|
self.examples_instruction_format = self.examples_instruction_format.replace("\\n", "\n")
|
||||||
|
if "\\n" in self.query_instruction_format_for_rerank:
|
||||||
|
self.query_instruction_format_for_rerank = self.query_instruction_format_for_rerank.replace("\\n", "\n")
|
||||||
|
if "\\n" in self.passage_instruction_format_for_rerank:
|
||||||
|
self.passage_instruction_format_for_rerank = self.passage_instruction_format_for_rerank.replace("\\n", "\n")
|
||||||
|
@ -172,6 +172,8 @@ class ICLLLMEmbedder(AbsEmbedder):
|
|||||||
Returns:
|
Returns:
|
||||||
str: The complete example following the given format.
|
str: The complete example following the given format.
|
||||||
"""
|
"""
|
||||||
|
if "\\n" in instruction_format:
|
||||||
|
instruction_format = instruction_format.replace("\\n", "\n")
|
||||||
return instruction_format.format(instruction, query, response)
|
return instruction_format.format(instruction, query, response)
|
||||||
|
|
||||||
def stop_self_query_pool(self):
|
def stop_self_query_pool(self):
|
||||||
|
@ -90,6 +90,13 @@ class ModelArgs:
|
|||||||
embedder_passage_max_length: int = field(
|
embedder_passage_max_length: int = field(
|
||||||
default=512, metadata={"help": "Max length for passage."}
|
default=512, metadata={"help": "Max length for passage."}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# replace "\\n" with "\n"
|
||||||
|
if "\\n" in self.query_instruction_format_for_retrieval:
|
||||||
|
self.query_instruction_format_for_retrieval = self.query_instruction_format_for_retrieval.replace("\\n", "\n")
|
||||||
|
if "\\n" in self.examples_instruction_format:
|
||||||
|
self.examples_instruction_format = self.examples_instruction_format.replace("\\n", "\n")
|
||||||
|
|
||||||
|
|
||||||
def create_index(embeddings: np.ndarray, use_gpu: bool = False):
|
def create_index(embeddings: np.ndarray, use_gpu: bool = False):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user