import json
import logging
import os
import random
import tarfile
import tempfile
import uuid
from itertools import islice
from pathlib import Path
import requests
from tqdm import tqdm
logger = logging.getLogger(__name__)
DOWNSTREAM_TASK_MAP = {
"squad20": "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-downstream/squad20.tar.gz",
"covidqa": "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-downstream/covidqa.tar.gz",
}
def read_dpr_json(file, max_samples=None, proxies=None, num_hard_negatives=1, num_positives=1,
                  shuffle_negatives=True, shuffle_positives=False):
    """
    Reads a Dense Passage Retrieval (DPR) data file in json or jsonl format and returns a list of dictionaries.

    :param file: filename of DPR data in json/jsonl format
    :param max_samples: maximum number of samples to load (randomly sampled); load all samples if None
    :param proxies: optional proxies for downloading the file if it is not found locally
    :param num_hard_negatives: number of hard negative passages to keep per query
    :param num_positives: number of positive passages to keep per query
    :param shuffle_negatives: whether to shuffle the hard negative passages before truncating to num_hard_negatives
    :param shuffle_positives: whether to shuffle the positive passages before truncating to num_positives

    Returns:
        list of dictionaries: List[dict]
        each dictionary: {
                          "query": str -> query_text
                          "passages": List[dict] -> [{"text": document_text, "title": xxx, "label": "positive", "external_id": abb123},
                                                     {"text": document_text, "title": xxx, "label": "hard_negative", "external_id": abb134},
                                                     ...]
                         }
    Example:
        [{"query": 'who sings does he love me with reba',
          "passages": [{'title': 'Does He Love You',
'text': 'Does He Love You "Does He Love You" is a song written by Sandy Knox and Billy Stritch, and recorded as a duet by American country music artists Reba McEntire and Linda Davis. It was released in August 1993 as the first single from Reba\'s album "Greatest Hits Volume Two". It is one of country music\'s several songs about a love triangle. "Does He Love You" was written in 1982 by Billy Stritch. He recorded it with a trio in which he performed at the time, because he wanted a song that could be sung by the other two members',
'label': 'positive',
'external_id': '11828866'},
{'title': 'When the Nightingale Sings',
'text': "When the Nightingale Sings When The Nightingale Sings is a Middle English poem, author unknown, recorded in the British Library's Harley 2253 manuscript, verse 25. It is a love poem, extolling the beauty and lost love of an unknown maiden. When þe nyhtegale singes þe wodes waxen grene.<br> Lef ant gras ant blosme springes in aueryl y wene,<br> Ant love is to myn herte gon wiþ one spere so kene<br> Nyht ant day my blod hit drynkes myn herte deþ me tene. Ich have loved al þis er þat y may love namore,<br> Ich have siked moni syk lemmon for",
'label': 'hard_negative',
'external_id': '10891637'}]
        }]
"""
    # Path() lets us use .suffix below even if a plain string was passed in
    file = Path(file)
    # get remote dataset if needed
    if not os.path.exists(file):
        logger.info(f"Couldn't find {file} locally. Trying to download ...")
        _download_extract_downstream_data(file, proxies=proxies)

    if file.suffix.lower() == ".jsonl":
        dicts = []
        with open(file, encoding="utf-8") as f:
            for line in f:
                dicts.append(json.loads(line))
    else:
        with open(file, encoding="utf-8") as f:
            dicts = json.load(f)

    if max_samples:
        dicts = random.sample(dicts, min(max_samples, len(dicts)))
# convert DPR dictionary to standard dictionary
query_json_keys = ["question", "questions", "query"]
positive_context_json_keys = ["positive_contexts", "positive_ctxs", "positive_context", "positive_ctx"]
hard_negative_json_keys = ["hard_negative_contexts", "hard_negative_ctxs", "hard_negative_context", "hard_negative_ctx"]
standard_dicts = []
    for dpr_dict in dicts:  # `dpr_dict` instead of `dict` to avoid shadowing the built-in
        sample = {}
        passages = []
        for key, val in dpr_dict.items():
if key in query_json_keys:
sample["query"] = val
elif key in positive_context_json_keys:
if shuffle_positives:
random.shuffle(val)
for passage in val[:num_positives]:
passages.append({
"title": passage["title"],
"text": passage["text"],
"label": "positive",
"external_id": passage.get("passage_id", uuid.uuid4().hex.upper()[0:8])
})
elif key in hard_negative_json_keys:
if shuffle_negatives:
random.shuffle(val)
for passage in val[:num_hard_negatives]:
passages.append({
"title": passage["title"],
"text": passage["text"],
"label": "hard_negative",
"external_id": passage.get("passage_id", uuid.uuid4().hex.upper()[0:8])
})
sample["passages"] = passages
standard_dicts.append(sample)
return standard_dicts
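
# Usage sketch for read_dpr_json (hypothetical file path; assumes a DPR-style train file on disk):
#
#   dicts = read_dpr_json(Path("data/retriever/nq-train.json"), max_samples=100)
#   print(dicts[0]["query"], len(dicts[0]["passages"]))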
def read_squad_file(filename, proxies=None):
"""Read a SQuAD json file"""
if not (os.path.exists(filename)):
logger.info(f" Couldn't find {filename} locally. Trying to download ...")
_download_extract_downstream_data(filename, proxies)
with open(filename, "r", encoding="utf-8") as reader:
input_data = json.load(reader)["data"]
return input_data
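
# Usage sketch for read_squad_file (hypothetical path; assumes the standard SQuAD v2.0 layout,
# where each entry in "data" has a "title" and a list of "paragraphs"):
#
#   input_data = read_squad_file("data/squad20/dev-v2.0.json")
#   print(len(input_data), input_data[0]["title"])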
def write_squad_predictions(predictions, out_filename, predictions_filename=None):
predictions_json = {}
for x in predictions:
for p in x["predictions"]:
if p["answers"][0]["answer"] is not None:
predictions_json[p["question_id"]] = p["answers"][0]["answer"]
else:
predictions_json[p["question_id"]] = "" #convert No answer = None to format understood by the SQuAD eval script
if predictions_filename:
dev_labels = {}
        with open(predictions_filename, "r", encoding="utf-8") as f:
            temp = json.load(f)
        for d in temp["data"]:
for p in d["paragraphs"]:
for q in p["qas"]:
if q.get("is_impossible",False):
dev_labels[q["id"]] = "is_impossible"
else:
dev_labels[q["id"]] = q["answers"][0]["text"]
        not_included = set(dev_labels.keys()) - set(predictions_json.keys())
        if len(not_included) > 0:
            logger.info(f"There were missing predictions for question ids: {list(not_included)}")
for x in not_included:
predictions_json[x] = ""
# os.makedirs("model_output", exist_ok=True)
# filepath = Path("model_output") / out_filename
    with open(out_filename, "w") as f:
        json.dump(predictions_json, f)
    logger.info(f"Written SQuAD predictions to: {out_filename}")
def http_get(url, temp_file, proxies=None):
    """Stream the file at `url` into the open (binary) file handle `temp_file`, showing a progress bar."""
    req = requests.get(url, stream=True, proxies=proxies)
    req.raise_for_status()  # fail early on HTTP errors instead of writing an error page to disk
    content_length = req.headers.get("Content-Length")
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()
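
# Usage sketch for http_get (local output filename is hypothetical; the URL comes from the
# DOWNSTREAM_TASK_MAP defined above):
#
#   with open("squad20.tar.gz", "wb") as f:
#       http_get(DOWNSTREAM_TASK_MAP["squad20"], f)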
def grouper(iterable, n, worker_id=0, total_workers=1):
"""
Split an iterable into a list of n-sized chunks. Each element in the chunk is a tuple of (index_num, element).
Example:
list(grouper('ABCDEFG', 3))
[[(0, 'A'), (1, 'B'), (2, 'C')], [(3, 'D'), (4, 'E'), (5, 'F')], [(6, 'G')]]
Use with the StreamingDataSilo
When StreamingDataSilo is used with multiple PyTorch DataLoader workers, the generator
yielding dicts(that gets converted to datasets) is replicated across the workers.
To avoid duplicates, we split the dicts across workers by creating a new generator for
each worker using this method.
Input --> [dictA, dictB, dictC, dictD, dictE, ...] with total worker=3 and n=2
Output for worker 1: [(dictA, dictB), (dictG, dictH), ...]
Output for worker 2: [(dictC, dictD), (dictI, dictJ), ...]
Output for worker 3: [(dictE, dictF), (dictK, dictL), ...]
This method also adds an index number to every dict yielded.
:param iterable: a generator object that yields dicts
:type iterable: generator
:param n: the dicts are grouped in n-sized chunks that gets converted to datasets
:type n: int
:param worker_id: the worker_id for the PyTorch DataLoader
:type worker_id: int
:param total_workers: total number of workers for the PyTorch DataLoader
:type total_workers: int
"""
# TODO make me comprehensible :)
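    # Overview of the two helpers below:
    # - get_iter_start_pos skips the first worker_id * n elements so each worker starts
    #   at its own offset in the shared stream.
    # - filter_elements_per_worker then repeatedly yields n elements and skips the
    #   (total_workers - 1) * n elements that belong to the other workers.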
def get_iter_start_pos(gen):
start_pos = worker_id * n
for i in gen:
if start_pos:
start_pos -= 1
continue
yield i
def filter_elements_per_worker(gen):
x = n
y = (total_workers - 1) * n
for i in gen:
if x:
yield i
x -= 1
else:
if y != 1:
y -= 1
continue
else:
x = n
y = (total_workers - 1) * n
iterable = iter(enumerate(iterable))
iterable = get_iter_start_pos(iterable)
if total_workers > 1:
iterable = filter_elements_per_worker(iterable)
return iter(lambda: list(islice(iterable, n)), [])
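
# Usage sketch for grouper with one worker vs. two workers (toy data instead of DPR dicts):
#
#   chunks = list(grouper(["a", "b", "c", "d", "e"], n=2))
#   # -> [[(0, 'a'), (1, 'b')], [(2, 'c'), (3, 'd')], [(4, 'e')]]
#   worker0 = list(grouper(["a", "b", "c", "d"], n=1, worker_id=0, total_workers=2))
#   # -> [[(0, 'a')], [(2, 'c')]]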
def _download_extract_downstream_data(input_file, proxies=None):
# download archive to temp dir and extract to correct position
full_path = Path(os.path.realpath(input_file))
directory = full_path.parent
taskname = directory.stem
datadir = directory.parent
    logger.info(f"downloading and extracting file {taskname} to dir {datadir}")
    if taskname not in DOWNSTREAM_TASK_MAP:
        logger.error(f"Cannot download {taskname}. Unknown data source.")
    else:
        if os.name == "nt":  # NamedTemporaryFile on Windows can't be reopened while still open, so don't auto-delete it
            delete_tmp_file = False
        else:
            delete_tmp_file = True
        with tempfile.NamedTemporaryFile(delete=delete_tmp_file) as temp_file:
            http_get(DOWNSTREAM_TASK_MAP[taskname], temp_file, proxies=proxies)
            temp_file.flush()
            temp_file.seek(0)  # making tempfile accessible
            with tarfile.open(temp_file.name) as tfile:
                tfile.extractall(datadir)
        # temp_file gets deleted here (on Windows, delete=False means it is left behind in the temp dir)
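
# Usage sketch for _download_extract_downstream_data (hypothetical path; the parent directory
# name, e.g. "squad20", must be a key of DOWNSTREAM_TASK_MAP, and the archive is extracted one
# level above it):
#
#   _download_extract_downstream_data("data/squad20/train-v2.0.json")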
def is_json(x):
    """Return True if `x` is JSON-serializable (Path objects are also treated as serializable)."""
    if isinstance(x, Path):
        return True
    try:
        json.dumps(x)
        return True
    except Exception:
        return False
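
# Usage sketch for is_json:
#
#   is_json({"query": "who sings does he love me with reba"})  # -> True
#   is_json(Path("data/squad20"))                              # -> True (Paths are treated as serializable)
#   is_json({"bytes": b"\x00"})                                # -> False (bytes are not JSON-serializable)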