haystack/test/benchmarks/nq_to_squad.py

#!/usr/bin/python3
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
DEEPSET DOCSTRING:
A modified version of the script from here:
https://github.com/google/retrieval-qa-eval/blob/master/nq_to_squad.py
Edits have been made by deepset in order to create a dev set for Haystack benchmarking.
Input should be the official NQ dev set (v1.0-simplified-nq-dev-all.jsonl.gz)
Expected numbers are:
Converted 7830 NQ records into 5678 SQuAD records.
Removed samples: yes/no: 177 multi_short: 648 non_para: 1192 long_ans_only: 130 errors: 5
Removed annotations: long_answer: 4610 short_answer: 953 no_answer: ~1006
where:
multi_short - annotations where there are multiple disjoint short answers
non_para - where the annotation occurs in an html element that is not a paragraph
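For the Haystack benchmark conversion, an invocation along these lines should work (the output
file name here is only an example):
python nq_to_squad.py --data_pattern=v1.0-simplified-nq-dev-all.jsonl.gz --output_file=nq_dev_as_squad.json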
ORIGINAL DOCSTRING:
Convert the Natural Questions dataset into SQuAD JSON format.
To use this utility, first follow the directions at the URL below to download
the complete training dataset.
https://ai.google.com/research/NaturalQuestions/download
Next, run this program, specifying the data you wish to convert. For instance,
the invocation:
python nq_to_squad.py\
--data_pattern=/usr/local/data/tnq/v1.0/train/*.gz\
--output_file=/usr/local/data/tnq/v1.0/train.json
will process all training data and write the results into `train.json`. This
file can, in turn, be provided to squad_eval.py using the --squad argument.
"""
import argparse
import glob
import gzip
import json
import logging
import os
import re

# Dropped samples
n_yn = 0
n_ms = 0
n_non_p = 0
n_long_ans_only = 0
n_error = 0
# Dropped annotations
n_long_ans = 0
n_no_ans = 0
n_short = 0


def clean_text(start_token, end_token, doc_tokens, doc_bytes, ignore_final_whitespace=True):
"""Remove HTML tags from a text span and reconstruct proper spacing."""
text = ""
for index in range(start_token, end_token):
token = doc_tokens[index]
if token["html_token"]:
continue
text += token["token"]
# Add a single space between two tokens iff there is at least one
# whitespace character between them (outside of an HTML tag). For example:
#
# token1 token2 ==> Add space.
# token1</B> <B>token2 ==> Add space.
# token1</A>token2 ==> No space.
# token1<A href="..." title="...">token2 ==> No space.
# token1<SUP>2</SUP>token2 ==> No space.
next_token = token
last_index = end_token if ignore_final_whitespace else end_token + 1
for next_token in doc_tokens[index + 1 : last_index]:
if not next_token["html_token"]:
break
chars = doc_bytes[token["end_byte"] : next_token["start_byte"]].decode("utf-8")
# Since some HTML tags are missing from the token list, we count '<' and
# '>' to detect if we're inside a tag.
unclosed_brackets = 0
for char in chars:
if char == "<":
unclosed_brackets += 1
elif char == ">":
unclosed_brackets -= 1
elif unclosed_brackets == 0 and re.match(r"\s", char):
# Add a single space after this token.
text += " "
break
return text
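

# Illustrative (made-up) example of how clean_text() stitches tokens back together. The byte offsets
# below are chosen to line up with the toy HTML string; they are not taken from real NQ data:
#
#   doc_bytes = b"token1</B> <B>token2"
#   doc_tokens = [
#       {"token": "token1", "html_token": False, "start_byte": 0, "end_byte": 6},
#       {"token": "</B>", "html_token": True, "start_byte": 6, "end_byte": 10},
#       {"token": "<B>", "html_token": True, "start_byte": 11, "end_byte": 14},
#       {"token": "token2", "html_token": False, "start_byte": 14, "end_byte": 20},
#   ]
#   clean_text(0, len(doc_tokens), doc_tokens, doc_bytes)  # -> "token1 token2"

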
def get_anno_type(annotation):
long_answer = annotation["long_answer"]
short_answers = annotation["short_answers"]
yes_no_answer = annotation["yes_no_answer"]
if len(short_answers) > 1:
return "multi_short"
elif yes_no_answer != "NONE":
return yes_no_answer
elif len(short_answers) == 1:
return "short_answer"
elif len(short_answers) == 0:
if long_answer["start_token"] == -1:
return "no_answer"
else:
return "long_answer"
def reduce_annotations(anno_types, answers):
"""
In cases where there is annotator disagreement, this fn picks either only the short_answers or only the no_answers,
depending on which is more numerous, with a bias towards picking short_answers.
Note: By this stage, all long_answer annotations and all samples with yes/no answer have been removed.
This leaves just no_answer and short_answers"""
for at in set(anno_types):
assert at in ("no_answer", "short_answer")
if anno_types.count("short_answer") >= anno_types.count("no_answer"):
majority = "short_answer"
is_impossible = False
else:
majority = "no_answer"
is_impossible = True
answers = [a for at, a in zip(anno_types, answers) if at == majority]
reduction = len(anno_types) - len(answers)
assert reduction < 3
if not is_impossible:
global n_no_ans
n_no_ans += reduction
else:
global n_short
n_short += reduction
answers = []
return answers, is_impossible
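

# Illustrative example: with anno_types = ["short_answer", "no_answer", "short_answer"], the short
# answers form the majority, so the single no_answer annotation is dropped (n_no_ans is incremented
# by 1), the two short_answer entries are kept, and is_impossible is False.

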
def nq_to_squad(record):
"""Convert a Natural Questions record to SQuAD format."""
doc_bytes = record["document_html"].encode("utf-8")
doc_tokens = record["document_tokens"]
question_text = record["question_text"]
question_text = question_text[0].upper() + question_text[1:] + "?"
answers = []
anno_types = []
for annotation in record["annotations"]:
anno_type = get_anno_type(annotation)
long_answer = annotation["long_answer"]
short_answers = annotation["short_answers"]
if anno_type.lower() in ["yes", "no"]:
global n_yn
n_yn += 1
return
# Skip examples that don't have exactly one short answer.
# Note: Consider including multi-span short answers.
if anno_type == "multi_short":
global n_ms
n_ms += 1
return
elif anno_type == "short_answer":
short_answer = short_answers[0]
# Skip examples corresponding to HTML blocks other than <P>.
long_answer_html_tag = doc_tokens[long_answer["start_token"]]["token"]
if long_answer_html_tag != "<P>":
global n_non_p
n_non_p += 1
return
answer = clean_text(short_answer["start_token"], short_answer["end_token"], doc_tokens, doc_bytes)
before_answer = clean_text(
0, short_answer["start_token"], doc_tokens, doc_bytes, ignore_final_whitespace=False
)
elif anno_type == "no_answer":
answer = ""
before_answer = ""
# Throw out long answer annotations
elif anno_type == "long_answer":
global n_long_ans
n_long_ans += 1
continue
anno_types.append(anno_type)
answer = {"answer_start": len(before_answer), "text": answer}
answers.append(answer)
if len(answers) == 0:
global n_long_ans_only
n_long_ans_only += 1
return
answers, is_impossible = reduce_annotations(anno_types, answers)
paragraph = clean_text(0, len(doc_tokens), doc_tokens, doc_bytes)
return {
"title": record["document_title"],
"paragraphs": [
{
"context": paragraph,
"qas": [
{
"answers": answers,
"id": record["example_id"],
"question": question_text,
"is_impossible": is_impossible,
}
],
}
],
    }


def main():
parser = argparse.ArgumentParser(description="Convert the Natural Questions to SQuAD JSON format.")
parser.add_argument(
"--data_pattern",
dest="data_pattern",
help=("A file pattern to match the Natural Questions " "dataset."),
metavar="PATTERN",
required=True,
)
parser.add_argument(
"--version", dest="version", help="The version label in the output file.", metavar="LABEL", default="nq-train"
)
parser.add_argument(
"--output_file",
dest="output_file",
help="The name of the SQuAD JSON formatted output file.",
metavar="FILE",
default="nq_as_squad.json",
)
args = parser.parse_args()
root = logging.getLogger()
root.setLevel(logging.DEBUG)
records = 0
nq_as_squad = {"version": args.version, "data": []}
for file in sorted(glob.iglob(args.data_pattern)):
logging.info("opening %s", file)
with gzip.GzipFile(file, "r") as f:
for line in f:
records += 1
nq_record = json.loads(line)
try:
squad_record = nq_to_squad(nq_record)
                except Exception:
squad_record = None
global n_error
n_error += 1
if squad_record:
nq_as_squad["data"].append(squad_record)
if records % 100 == 0:
logging.info("processed %s records", records)
print("Converted %s NQ records into %s SQuAD records." % (records, len(nq_as_squad["data"])))
print(
f"Removed samples: yes/no: {n_yn} multi_short: {n_ms} non_para {n_non_p} long_ans_only: {n_long_ans_only} errors: {n_error}"
)
print(f"Removed annotations: long_answer: {n_long_ans} short_answer: {n_short} no_answer: ~{n_no_ans}")
with open(args.output_file, "w") as f:
        json.dump(nq_as_squad, f, indent=4)


if __name__ == "__main__":
main()