Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-11-02 02:39:51 +00:00
Validate max_seq_length in SquadProcessor (#2740)
* added max_seq_len validation in SquadProcessor
* fixed string formatting
* added tests for invalid max_seq_len
* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
parent ffb7e4e4bd
commit 31dcd55c24
@@ -475,10 +475,11 @@ class DataSilo:
         logger.info("Proportion clipped: {}".format(clipped))
         if clipped > 0.5:
             logger.info(
-                "[Haystack Tip] {}% of your samples got cut down to {} tokens. "
-                "Consider increasing max_seq_len. "
-                "This will lead to higher memory consumption but is likely to "
-                "improve your model performance".format(round(clipped * 100, 1), max_seq_len)
+                f"[Haystack Tip] {round(clipped * 100, 1)}% of your samples got cut down to {max_seq_len} tokens. "
+                "Consider increasing max_seq_len "
+                f"(the maximum value allowed with the current model is max_seq_len={self.processor.tokenizer.model_max_length}, "
+                "if this is not enough consider splitting the document in smaller units or changing the model). "
+                "This will lead to higher memory consumption but is likely to improve your model performance"
             )
         elif "query_input_ids" in self.tensor_names and "passage_input_ids" in self.tensor_names:
             logger.info(
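For reference, here is how the rewritten tip message renders. This is a standalone sketch, not part of the commit; clipped, max_seq_len, and model_max_length are stand-in values:

# Stand-in values for illustration (not from the commit).
clipped = 0.73          # proportion of samples that got truncated
max_seq_len = 256       # current processor setting
model_max_length = 512  # what self.processor.tokenizer.model_max_length would report

message = (
    f"[Haystack Tip] {round(clipped * 100, 1)}% of your samples got cut down to {max_seq_len} tokens. "
    "Consider increasing max_seq_len "
    f"(the maximum value allowed with the current model is max_seq_len={model_max_length}, "
    "if this is not enough consider splitting the document in smaller units or changing the model). "
    "This will lead to higher memory consumption but is likely to improve your model performance"
)
print(message)  # -> "[Haystack Tip] 73.0% of your samples got cut down to 256 tokens. ..."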
@@ -408,6 +408,13 @@ class SquadProcessor(Processor):
         """
         self.ph_output_type = "per_token_squad"
 
+        # validate max_seq_len
+        assert max_seq_len <= tokenizer.model_max_length, (
+            "max_seq_len cannot be greater than the maximum sequence length handled by the model: "
+            f"got max_seq_len={max_seq_len}, while the model maximum length is {tokenizer.model_max_length}. "
+            "Please adjust max_seq_len accordingly or use another model "
+        )
+
         assert doc_stride < (max_seq_len - max_query_length), (
             "doc_stride ({}) is longer than max_seq_len ({}) minus space reserved for query tokens ({}). \nThis means that there will be gaps "
             "as the passage windows slide, causing the model to skip over parts of the document.\n"
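Taken together, the new assert and the existing doc_stride check enforce two relationships between the tokenization parameters. A minimal sketch of the arithmetic, with illustrative values that are not taken from the commit:

# Illustrative values (e.g. roberta-base reports tokenizer.model_max_length == 512).
model_max_length = 512
max_seq_len = 384
max_query_length = 64
doc_stride = 128

# 1) The processor may not request longer sequences than the model can encode.
assert max_seq_len <= model_max_length

# 2) The passage windows must overlap as they slide; the query tokens reserve
#    part of each window, shrinking the space left for passage text.
assert doc_stride < (max_seq_len - max_query_length)  # 128 < 384 - 64 = 320, valid

# By contrast, max_seq_len=129 with doc_stride=128 fails check 2
# (128 is not < 129 - 64 = 65), which is exactly what the tests below exercise.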
@@ -490,6 +497,13 @@ class SquadProcessor(Processor):
         ["text", "questions"] (api format). This function converts the latter into the former. It also converts the
         is_impossible field to answer_type so that NQ and SQuAD dicts have the same format.
         """
+        # validate again max_seq_len
+        assert self.max_seq_len <= self.tokenizer.model_max_length, (
+            "max_seq_len cannot be greater than the maximum sequence length handled by the model: "
+            f"got max_seq_len={self.max_seq_len}, while the model maximum length is {self.tokenizer.model_max_length}. "
+            "Please adjust max_seq_len accordingly or use another model "
+        )
+
         # check again for doc stride vs max_seq_len when. Parameters can be changed for already initialized models (e.g. in haystack)
         assert self.doc_stride < (self.max_seq_len - self.max_query_length), (
             "doc_stride ({}) is longer than max_seq_len ({}) minus space reserved for query tokens ({}). \nThis means that there will be gaps "
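The check reappears in dataset_from_dicts because __init__ runs only once, while the processor's attributes can be mutated afterwards (e.g. via FARMReader.update_parameters in Haystack). A toy sketch, not Haystack code, of the failure mode this guards against:

# Toy illustration only: attribute mutation bypasses the __init__ checks.
class ToyProcessor:
    def __init__(self, max_seq_len, model_max_length):
        assert max_seq_len <= model_max_length  # first line of defense
        self.max_seq_len = max_seq_len
        self.model_max_length = model_max_length

    def dataset_from_dicts(self, dicts):
        # second line of defense: the attribute may have changed since __init__
        assert self.max_seq_len <= self.model_max_length
        return dicts

p = ToyProcessor(max_seq_len=384, model_max_length=512)
p.max_seq_len = 10_000           # mutated after init, no error raised yet
# p.dataset_from_dicts([...])    # would now fail the repeated assert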
@@ -177,6 +177,22 @@ def test_top_k(reader, docs, top_k):
         print("WARNING: Could not set `top_k_per_sample` in FARM. Please update FARM version.")
 
 
+def test_farm_reader_invalid_params():
+    # invalid max_seq_len (greater than model maximum seq length)
+    with pytest.raises(Exception):
+        reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=False, max_seq_len=513)
+
+    # invalid max_seq_len (max_seq_len >= doc_stride)
+    with pytest.raises(Exception):
+        reader = FARMReader(
+            model_name_or_path="deepset/roberta-base-squad2", use_gpu=False, max_seq_len=129, doc_stride=128
+        )
+
+    # invalid doc_stride (doc_stride >= (max_seq_len - max_query_length))
+    with pytest.raises(Exception):
+        reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=False, doc_stride=999)
+
+
 def test_farm_reader_update_params(docs):
     reader = FARMReader(
         model_name_or_path="deepset/roberta-base-squad2", use_gpu=False, no_ans_boost=0, num_processes=0
@@ -219,6 +235,14 @@ def test_farm_reader_update_params(docs):
         reader.update_parameters(context_window_size=6, no_ans_boost=-10, max_seq_len=99, doc_stride=128)
         reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)
 
+    # update max_seq_len with invalid value (greater than the model maximum sequence length)
+    with pytest.raises(Exception):
+        invalid_max_seq_len = reader.inferencer.processor.tokenizer.model_max_length + 1
+        reader.update_parameters(
+            context_window_size=100, no_ans_boost=-10, max_seq_len=invalid_max_seq_len, doc_stride=128
+        )
+        reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)
+
 
 @pytest.mark.parametrize("use_confidence_scores", [True, False])
 def test_farm_reader_uses_same_sorting_as_QAPredictionHead(use_confidence_scores):
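Callers who want to avoid tripping these asserts can derive safe values from the tokenizer up front. A hedged sketch, not part of the commit; max_query_length=64 is an assumption matching the processor default:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
max_query_length = 64                                      # assumed processor default
max_seq_len = min(384, tokenizer.model_max_length)         # never exceed the model limit
doc_stride = min(128, max_seq_len - max_query_length - 1)  # keep passage windows overlapping

# These mirror the two constraints the commit enforces:
assert max_seq_len <= tokenizer.model_max_length
assert doc_stride < (max_seq_len - max_query_length)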