# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""
DEEPSET DOCSTRING:

A modified version of the script from here:
https://github.com/google/retrieval-qa-eval/blob/main/nq_to_squad.py
Edits have been made by deepset in order to create a dev set for Haystack benchmarking.
Input should be the official NQ dev set (v1.0-simplified-nq-dev-all.jsonl.gz).

Expected numbers are:
Converted 7830 NQ records into 5678 SQuAD records.
Removed samples: yes/no: 177 multi_short: 648 non_para: 1192 long_ans_only: 130 errors: 5
Removed annotations: long_answer: 4610 short_answer: 953 no_answer: ~1006
where:
multi_short - annotations where there are multiple disjoint short answers
non_para - where the annotation occurs in an HTML element that is not a paragraph
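
For example (the paths here are illustrative only; any pattern accepted by glob works):

    python nq_to_squad.py \
        --data_pattern=v1.0-simplified-nq-dev-all.jsonl.gz \
        --output_file=nq_dev_as_squad.json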
					
						
ORIGINAL DOCSTRING:

Convert the Natural Questions dataset into SQuAD JSON format.

To use this utility, first follow the directions at the URL below to download
the complete training dataset.

    https://ai.google.com/research/NaturalQuestions/download

Next, run this program, specifying the data you wish to convert. For instance,
the invocation:

    python nq_to_squad.py \
        --data_pattern=/usr/local/data/tnq/v1.0/train/*.gz \
        --output_file=/usr/local/data/tnq/v1.0/train.json

will process all training data and write the results into `train.json`. This
file can, in turn, be provided to squad_eval.py using the --squad argument.
"""
					
						
import argparse
import glob
import gzip
import json
import logging
import re

# Dropped samples
n_yn = 0  # yes/no answers
n_ms = 0  # multiple disjoint short answers
n_non_p = 0  # long answer is not in a <P> element
n_long_ans_only = 0  # only long-answer annotations
n_error = 0  # conversion raised an exception

# Dropped annotations
n_long_ans = 0
n_no_ans = 0
n_short = 0
					
						
def clean_text(start_token, end_token, doc_tokens, doc_bytes, ignore_final_whitespace=True):
    """Remove HTML tags from a text span and reconstruct proper spacing."""
    text = ""
    for index in range(start_token, end_token):
        token = doc_tokens[index]
        if token["html_token"]:
            continue
        text += token["token"]
        # Add a single space between two tokens iff there is at least one
        # whitespace character between them (outside of an HTML tag). For example:
        #
        #   token1 token2                           ==> Add space.
        #   token1</B> <B>token2                    ==> Add space.
        #   token1</A>token2                        ==> No space.
        #   token1<A href="..." title="...">token2  ==> No space.
        #   token1<SUP>2</SUP>token2                ==> No space.
        next_token = token
        last_index = end_token if ignore_final_whitespace else end_token + 1
        for next_token in doc_tokens[index + 1 : last_index]:
            if not next_token["html_token"]:
                break
        chars = doc_bytes[token["end_byte"] : next_token["start_byte"]].decode("utf-8")
        # Since some HTML tags are missing from the token list, we count '<' and
        # '>' to detect if we're inside a tag.
        unclosed_brackets = 0
        for char in chars:
            if char == "<":
                unclosed_brackets += 1
            elif char == ">":
                unclosed_brackets -= 1
            elif unclosed_brackets == 0 and re.match(r"\s", char):
                # Add a single space after this token.
                text += " "
                break
    return text
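
# Illustrative sketch (hypothetical data, not from the dataset): how clean_text
# behaves on a tiny hand-built token list. The dicts carry exactly the fields
# the function reads (token, html_token, start_byte, end_byte).
#
#   doc_bytes = b"<P>Hello <B>world</B></P>"
#   doc_tokens = [
#       {"token": "<P>", "html_token": True, "start_byte": 0, "end_byte": 3},
#       {"token": "Hello", "html_token": False, "start_byte": 3, "end_byte": 8},
#       {"token": "<B>", "html_token": True, "start_byte": 9, "end_byte": 12},
#       {"token": "world", "html_token": False, "start_byte": 12, "end_byte": 17},
#       {"token": "</B>", "html_token": True, "start_byte": 17, "end_byte": 21},
#   ]
#   clean_text(0, 5, doc_tokens, doc_bytes)  # -> "Hello world"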
					
						
def get_anno_type(annotation):
    """Classify an annotation as multi_short, YES/NO, short_answer, long_answer, or no_answer."""
    long_answer = annotation["long_answer"]
    short_answers = annotation["short_answers"]
    yes_no_answer = annotation["yes_no_answer"]

    if len(short_answers) > 1:
        return "multi_short"
    elif yes_no_answer != "NONE":
        return yes_no_answer
    elif len(short_answers) == 1:
        return "short_answer"
    elif len(short_answers) == 0:
        if long_answer["start_token"] == -1:
            return "no_answer"
        else:
            return "long_answer"
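
# For illustration, the mapping on a few hypothetical annotation shapes
# (yes_no_answer is "NONE" unless stated otherwise):
#
#   two entries in short_answers                        -> "multi_short"
#   yes_no_answer == "YES"                              -> "YES"
#   one entry in short_answers                          -> "short_answer"
#   no short answers, long_answer["start_token"] == -1  -> "no_answer"
#   no short answers, long_answer["start_token"] >= 0   -> "long_answer"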
					
						
def reduce_annotations(anno_types, answers):
    """
    In cases of annotator disagreement, pick either only the short_answer annotations or only
    the no_answer annotations, whichever is more numerous, with ties resolved in favour of
    short_answer.

    Note: By this stage, all long_answer annotations and all samples with a yes/no answer have
    been removed, leaving just no_answer and short_answer.
    """
    for at in set(anno_types):
        assert at in ("no_answer", "short_answer")
    if anno_types.count("short_answer") >= anno_types.count("no_answer"):
        majority = "short_answer"
        is_impossible = False
    else:
        majority = "no_answer"
        is_impossible = True
    answers = [a for at, a in zip(anno_types, answers) if at == majority]
    reduction = len(anno_types) - len(answers)
    # NQ dev records carry 5 annotations, so the dropped minority is at most 2.
    assert reduction < 3
    if not is_impossible:
        global n_no_ans
        n_no_ans += reduction
    else:
        global n_short
        n_short += reduction
        answers = []
    return answers, is_impossible
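
# Worked example (hypothetical): anno_types = ["short_answer", "no_answer", "short_answer"]
# has a short_answer majority, so the single no_answer annotation is dropped
# (n_no_ans += 1) and both short answers are kept with is_impossible = False.
# A tie such as ["short_answer", "no_answer"] also resolves to short_answer
# because of the >= comparison.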
					
						
def nq_to_squad(record):
    """Convert a Natural Questions record to SQuAD format, returning None if the record is dropped."""

    doc_bytes = record["document_html"].encode("utf-8")
    doc_tokens = record["document_tokens"]

    question_text = record["question_text"]
    question_text = question_text[0].upper() + question_text[1:] + "?"

    answers = []
    anno_types = []
    for annotation in record["annotations"]:
        anno_type = get_anno_type(annotation)
        long_answer = annotation["long_answer"]
        short_answers = annotation["short_answers"]

        if anno_type.lower() in ["yes", "no"]:
            global n_yn
            n_yn += 1
            return

        # Skip examples that don't have exactly one short answer.
        # Note: Consider including multi-span short answers.
        if anno_type == "multi_short":
            global n_ms
            n_ms += 1
            return

        elif anno_type == "short_answer":
            short_answer = short_answers[0]
            # Skip examples corresponding to HTML blocks other than <P>.
            long_answer_html_tag = doc_tokens[long_answer["start_token"]]["token"]
            if long_answer_html_tag != "<P>":
                global n_non_p
                n_non_p += 1
                return
            answer = clean_text(short_answer["start_token"], short_answer["end_token"], doc_tokens, doc_bytes)
            before_answer = clean_text(
                0, short_answer["start_token"], doc_tokens, doc_bytes, ignore_final_whitespace=False
            )

        elif anno_type == "no_answer":
            answer = ""
            before_answer = ""

        # Throw out long answer annotations
        elif anno_type == "long_answer":
            global n_long_ans
            n_long_ans += 1
            continue

        anno_types.append(anno_type)
        # len(before_answer) is the answer's character offset in the cleaned
        # context: the context below is clean_text over the whole document, so
        # the cleaned prefix up to start_token lines up with the answer span.
        answer = {"answer_start": len(before_answer), "text": answer}
        answers.append(answer)

    if len(answers) == 0:
        global n_long_ans_only
        n_long_ans_only += 1
        return

    answers, is_impossible = reduce_annotations(anno_types, answers)

    paragraph = clean_text(0, len(doc_tokens), doc_tokens, doc_bytes)

    return {
        "title": record["document_title"],
        "paragraphs": [
            {
                "context": paragraph,
                "qas": [
                    {
                        "answers": answers,
                        "id": record["example_id"],
                        "question": question_text,
                        "is_impossible": is_impossible,
                    }
                ],
            }
        ],
    }
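
# For illustration, the returned record has this rough shape (values hypothetical):
#
#   {
#       "title": "Example page",
#       "paragraphs": [
#           {
#               "context": "Full cleaned page text ...",
#               "qas": [
#                   {
#                       "answers": [{"answer_start": 17, "text": "some span"}],
#                       "id": 1234567890,
#                       "question": "Who wrote the example?",
#                       "is_impossible": False,
#                   }
#               ],
#           }
#       ],
#   }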
					
						
def main():
    parser = argparse.ArgumentParser(description="Convert the Natural Questions to SQuAD JSON format.")
    parser.add_argument(
        "--data_pattern",
        dest="data_pattern",
        help="A file pattern to match the Natural Questions dataset.",
        metavar="PATTERN",
        required=True,
    )
    parser.add_argument(
        "--version", dest="version", help="The version label in the output file.", metavar="LABEL", default="nq-train"
    )
    parser.add_argument(
        "--output_file",
        dest="output_file",
        help="The name of the SQuAD JSON formatted output file.",
        metavar="FILE",
        default="nq_as_squad.json",
    )
    args = parser.parse_args()

    root = logging.getLogger()
    root.setLevel(logging.DEBUG)

    records = 0
    nq_as_squad = {"version": args.version, "data": []}

    for file in sorted(glob.iglob(args.data_pattern)):
        logging.info("opening %s", file)
        with gzip.GzipFile(file, "r") as f:
            for line in f:
                records += 1
                nq_record = json.loads(line)
                try:
                    squad_record = nq_to_squad(nq_record)
                except Exception:
                    squad_record = None
                    global n_error
                    n_error += 1
                if squad_record:
                    nq_as_squad["data"].append(squad_record)
                if records % 100 == 0:
                    logging.info("processed %s records", records)
    print("Converted %s NQ records into %s SQuAD records." % (records, len(nq_as_squad["data"])))
    print(
        f"Removed samples: yes/no: {n_yn} multi_short: {n_ms} non_para: {n_non_p} long_ans_only: {n_long_ans_only} errors: {n_error}"
    )
    print(f"Removed annotations: long_answer: {n_long_ans} short_answer: {n_short} no_answer: ~{n_no_ans}")

    with open(args.output_file, "w") as f:
        json.dump(nq_as_squad, f, indent=4)
					
						
if __name__ == "__main__":
    main()