Squad tools (#1029)

* Add first commit

* Add support for conversion to and from pandas df

* Add logging

* Add functionality

* Satisfy mypy

* Incorporate reviewer feedback
Branden Chan 2021-05-06 19:02:15 +02:00 committed by GitHub
parent 373fef8d1e
commit 5d31e633ce

haystack/squad_data.py Normal file

@@ -0,0 +1,284 @@
import json
import logging
import random
from typing import Dict, List, Union

import pandas as pd
from tqdm import tqdm

from haystack.schema import Document, Label
from farm.data_handler.utils import read_squad_file

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

tqdm.pandas()

COLUMN_NAMES = ["title", "context", "question", "id", "answer_text", "answer_start", "is_impossible"]


class SquadData:
    """This class is designed to manipulate data that is in SQuAD format."""

    def __init__(self, squad_data: Union[dict, List[dict]]):
        """
        :param squad_data: SQuAD format data, either as a dict with a `data` key, or just a list of SQuAD documents.
        """
        if isinstance(squad_data, dict):
            self.version = squad_data.get("version")
            self.data = squad_data["data"]
        elif isinstance(squad_data, list):
            self.version = None
            self.data = squad_data
        else:
            raise ValueError("squad_data must be a dict with a `data` key or a list of SQuAD documents")
        self.df = self.to_df(self.data)
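
    # A minimal usage sketch (file contents and paths are hypothetical, not part of this module):
    # the constructor accepts either a full SQuAD dict or a bare list of SQuAD documents.
    #
    #   sd = SquadData({"version": "v2.0", "data": [...]})  # dict with a `data` key
    #   sd = SquadData(squad_dict["data"])                   # or just the list of documents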

    def merge_from_file(self, filename: str):
        """Merge the contents of a SQuAD format json file with the data stored in this object."""
        with open(filename) as f:
            new_data = json.load(f)["data"]
        self.merge(new_data)

    def merge(self, new_data: List):
        """
        Merge data in SQuAD format with the data stored in this object.

        :param new_data: A list of SQuAD document data.
        """
        df_new = self.to_df(new_data)
        self.df = pd.concat([df_new, self.df])
        self.data = self.df_to_data(self.df)

    @classmethod
    def from_file(cls, filename: str):
        """Create a SquadData object by providing the name of a SQuAD format json file."""
        with open(filename) as f:
            data = json.load(f)
        return cls(data)

    def save(self, filename: str):
        """Write the data stored in this object to a json file."""
        with open(filename, "w") as f:
            squad_data = {"version": self.version, "data": self.data}
            json.dump(squad_data, f, indent=2)

    def to_dpr_dataset(self):
        raise NotImplementedError(
            "SquadData.to_dpr_dataset() not yet implemented. "
            "For now, have a look at the script at haystack/retriever/squad_to_dpr.py"
        )

    def to_document_objs(self):
        """Export all paragraphs stored in this object to haystack.Document objects."""
        df_docs = self.df[["title", "context"]]
        df_docs = df_docs.drop_duplicates()
        record_dicts = df_docs.to_dict("records")
        documents = [
            Document(
                text=rd["context"],
                id=rd["title"]
            ) for rd in record_dicts
        ]
        return documents

    def to_label_objs(self):
        """Export all labels stored in this object to haystack.Label objects."""
        df_labels = self.df[["id", "question", "answer_text", "answer_start"]]
        record_dicts = df_labels.to_dict("records")
        labels = [
            Label(
                question=rd["question"],
                answer=rd["answer_text"],
                is_correct_answer=True,
                is_correct_document=True,
                id=rd["id"],
                origin=rd.get("origin", "SquadData tool"),
                document_id=rd.get("document_id", None)
            ) for rd in record_dicts
        ]
        return labels
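
    # Sketch of one way the exported objects could be used downstream (assumes a haystack
    # DocumentStore such as ElasticsearchDocumentStore; the names here are illustrative only):
    #
    #   document_store.write_documents(squad_data.to_document_objs())
    #   document_store.write_labels(squad_data.to_label_objs())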

    @staticmethod
    def to_df(data):
        """Convert a list of SQuAD document dictionaries into a pandas dataframe (each row is one annotation)."""
        flat = []
        for document in data:
            title = document["title"]
            for paragraph in document["paragraphs"]:
                context = paragraph["context"]
                for question in paragraph["qas"]:
                    q = question["question"]
                    id = question["id"]
                    is_impossible = question["is_impossible"]
                    # For no_answer samples
                    if len(question["answers"]) == 0:
                        flat.append({"title": title,
                                     "context": context,
                                     "question": q,
                                     "id": id,
                                     "answer_text": "",
                                     "answer_start": None,
                                     "is_impossible": is_impossible})
                    # For span answer samples
                    else:
                        for answer in question["answers"]:
                            answer_text = answer["text"]
                            answer_start = answer["answer_start"]
                            flat.append({"title": title,
                                         "context": context,
                                         "question": q,
                                         "id": id,
                                         "answer_text": answer_text,
                                         "answer_start": answer_start,
                                         "is_impossible": is_impossible})
        df = pd.DataFrame.from_records(flat)
        return df
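
    # For illustration (values are hypothetical), one SQuAD question with a single span answer
    # becomes one flat row of the form:
    #
    #   {"title": "Normans", "context": "...", "question": "...", "id": "...",
    #    "answer_text": "France", "answer_start": 159, "is_impossible": False}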

    def count(self, unit="questions"):
        """
        Count the samples in the data. Choose from unit = "paragraphs", "questions", "answers", "no_answers", "span_answers".
        """
        c = 0
        for document in self.data:
            for paragraph in document["paragraphs"]:
                if unit == "paragraphs":
                    c += 1
                for question in paragraph["qas"]:
                    if unit == "questions":
                        c += 1
                    # Count no_answers
                    if len(question["answers"]) == 0:
                        if unit in ["answers", "no_answers"]:
                            c += 1
                    # Count span answers
                    else:
                        for answer in question["answers"]:
                            if unit in ["answers", "span_answers"]:
                                c += 1
        return c
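
    # For example (the numbers are whatever the loaded data contains), the same object can be
    # counted with different units:
    #
    #   squad_data.count(unit="questions")     # total questions
    #   squad_data.count(unit="no_answers")    # questions marked unanswerable
    #   squad_data.count(unit="span_answers")  # annotated answer spans
    #   squad_data.count(unit="answers")       # no_answers plus span_answers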

    def df_to_data(self, df):
        """Convert a dataframe into SQuAD format data (a list of SQuAD document dictionaries)."""
        logger.info("Converting data frame to SQuAD format data")

        # Aggregate the answers of each question
        logger.info("Aggregating the answers of each question")
        df_grouped_answers = df.groupby(["title", "context", "question", "id", "is_impossible"])
        df_aggregated_answers = df[["title", "context", "question", "id", "is_impossible"]].drop_duplicates().reset_index()
        answers = df_grouped_answers.progress_apply(self._aggregate_answers).rename("answers")
        answers = pd.DataFrame(answers).reset_index()
        df_aggregated_answers = pd.merge(df_aggregated_answers, answers)

        # Aggregate the questions of each passage
        logger.info("Aggregating the questions of each paragraph of each document")
        df_grouped_questions = df_aggregated_answers.groupby(["title", "context"])
        df_aggregated_questions = df[["title", "context"]].drop_duplicates().reset_index()
        questions = df_grouped_questions.progress_apply(self._aggregate_questions).rename("qas")
        questions = pd.DataFrame(questions).reset_index()
        df_aggregated_questions = pd.merge(df_aggregated_questions, questions)

        # Aggregate the paragraphs of each document
        logger.info("Aggregating the paragraphs of each document")
        df_grouped_paragraphs = df_aggregated_questions.groupby(["title"])
        df_aggregated_paragraphs = df[["title"]].drop_duplicates().reset_index()
        paragraphs = df_grouped_paragraphs.progress_apply(self._aggregate_passages).rename("paragraphs")
        paragraphs = pd.DataFrame(paragraphs).reset_index()
        df_aggregated_paragraphs = pd.merge(df_aggregated_paragraphs, paragraphs)

        df_aggregated_paragraphs = df_aggregated_paragraphs[["title", "paragraphs"]]
        ret = df_aggregated_paragraphs.to_dict("records")
        return ret
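
    # The returned records mirror the SQuAD nesting (shapes only, values elided):
    #
    #   [{"title": ...,
    #     "paragraphs": [{"context": ...,
    #                     "qas": [{"question": ..., "id": ..., "is_impossible": ...,
    #                              "answers": [{"text": ..., "answer_start": ...}]}]}]}]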

    @staticmethod
    def _aggregate_passages(x):
        x = x[["context", "qas"]]
        ret = x.to_dict("records")
        return ret

    @staticmethod
    def _aggregate_questions(x):
        x = x[["question", "id", "answers", "is_impossible"]]
        ret = x.to_dict("records")
        return ret

    @staticmethod
    def _aggregate_answers(x):
        x = x[["answer_text", "answer_start"]]
        x = x.rename(columns={"answer_text": "text"})
        # Span answer
        try:
            x["answer_start"] = x["answer_start"].astype(int)
            ret = x.to_dict("records")
        # No answer
        except ValueError:
            ret = []
        return ret

    def set_data(self, data):
        self.data = data
        self.df = self.to_df(data)

    def sample_questions(self, n):
        """
        Return a sample of n questions in SQuAD format (a list of SQuAD document dictionaries).
        Note that if the same question is asked about multiple different passages, this function treats
        it as a single question.
        """
        all_questions = self.get_all_questions()
        sampled_questions = random.sample(all_questions, n)
        df_sampled = self.df[self.df["question"].isin(sampled_questions)]
        return self.df_to_data(df_sampled)

    def get_all_paragraphs(self):
        """Return all paragraph strings."""
        return self.df["context"].unique().tolist()

    def get_all_questions(self):
        """
        Return all question strings. Note that if the same question appears for different paragraphs, it will be
        returned multiple times by this function.
        """
        df_questions = self.df[["title", "context", "question"]]
        df_questions = df_questions.drop_duplicates()
        questions = df_questions["question"].tolist()
        return questions

    def get_all_document_titles(self):
        """Return all document title strings."""
        return self.df["title"].unique().tolist()


if __name__ == "__main__":
    # Download the SQuAD dataset if it isn't at the target directory
    read_squad_file("../data/squad20/train-v2.0.json")

    filename1 = "../data/squad20/train-v2.0.json"
    filename2 = "../data/squad20/dev-v2.0.json"

    # Load file1 and take a sample of 10000 questions
    sd = SquadData.from_file(filename1)
    sample1 = sd.sample_questions(n=10000)

    # Set sd to now contain the sample of 10000 questions
    sd.set_data(sample1)

    # Merge sd with file2 and take a sample of 100 questions
    sd.merge_from_file(filename2)
    sample2 = sd.sample_questions(n=100)
    sd.set_data(sample2)

    # Save this sample of 100
    sd.save("../data/squad20/sample.json")

    paragraphs = sd.get_all_paragraphs()
    questions = sd.get_all_questions()
    titles = sd.get_all_document_titles()

    documents = sd.to_document_objs()
    labels = sd.to_label_objs()

    n_qs = sd.count(unit="questions")
    n_as = sd.count(unit="no_answers")
    n_ps = sd.count(unit="paragraphs")

    print(n_qs)
    print(n_as)
    print(n_ps)