"""Tools for loading, inspecting, sampling, merging and saving question
answering data in SQuAD format (https://rajpurkar.github.io/SQuAD-explorer/)."""

import json
import logging
import random
from typing import Dict, List, Union

import pandas as pd
from tqdm import tqdm

from haystack.schema import Document, Label

from farm.data_handler.utils import read_squad_file

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Registers `DataFrame.progress_apply` so the long groupby aggregations below
# show a progress bar.
tqdm.pandas()

COLUMN_NAMES = ["title", "context", "question", "id", "answer_text", "answer_start", "is_impossible"]


class SquadData:
    """This class is designed to manipulate data that is in SQuAD format"""

    def __init__(self, squad_data: Union[Dict, List]):
        """
        :param squad_data: SQuAD format data, either as a dict with a `data` key,
                           or just a list of SQuAD documents
        :raises ValueError: if squad_data is neither a dict nor a list
        """
        if isinstance(squad_data, dict):
            self.version = squad_data.get("version")
            self.data = squad_data["data"]
        elif isinstance(squad_data, list):
            self.version = None
            self.data = squad_data
        else:
            # Previously an unsupported type left self.data unset and crashed
            # later with a confusing AttributeError; fail fast instead.
            raise ValueError("squad_data must be a dict with a 'data' key or a list of SQuAD documents")
        # Keep a flat dataframe view (one row per annotation) alongside the
        # nested SQuAD structure; most query methods operate on the dataframe.
        self.df = self.to_df(self.data)

    def merge_from_file(self, filename: str):
        """Merge the contents of a SQuAD format json file with the data stored in this object

        :param filename: Path to a SQuAD format json file
        """
        # Context manager ensures the file handle is closed (the original
        # `json.load(open(filename))` leaked it).
        with open(filename, encoding="utf-8") as f:
            new_data = json.load(f)["data"]
        self.merge(new_data)

    def merge(self, new_data: List):
        """
        Merge data in SQuAD format with the data stored in this object

        :param new_data: A list of SQuAD document data
        """
        df_new = self.to_df(new_data)
        self.df = pd.concat([df_new, self.df])
        # Keep the nested representation in sync with the merged dataframe.
        self.data = self.df_to_data(self.df)

    @classmethod
    def from_file(cls, filename: str):
        """
        Create a SquadData object by providing the name of a SQuAD format json file

        :param filename: Path to a SQuAD format json file
        """
        with open(filename, encoding="utf-8") as f:
            data = json.load(f)
        return cls(data)

    def save(self, filename: str):
        """Write the data stored in this object to a json file

        :param filename: Destination path
        """
        with open(filename, "w", encoding="utf-8") as f:
            squad_data = {"version": self.version, "data": self.data}
            json.dump(squad_data, f, indent=2)

    def to_dpr_dataset(self):
        raise NotImplementedError(
            "SquadData.to_dpr_dataset() not yet implemented. "
            "For now, have a look at the script at haystack/retriever/squad_to_dpr.py"
        )

    def to_document_objs(self):
        """Export all paragraphs stored in this object to haystack.Document objects"""
        df_docs = self.df[["title", "context"]]
        df_docs = df_docs.drop_duplicates()
        record_dicts = df_docs.to_dict("records")
        # NOTE(review): the Document id is the title, so two different
        # paragraphs under the same title receive the same id — confirm this
        # is intended.
        documents = [
            Document(
                text=rd["context"],
                id=rd["title"]
            ) for rd in record_dicts
        ]
        return documents

    def to_label_objs(self):
        """Export all labels stored in this object to haystack.Label objects"""
        df_labels = self.df[["id", "question", "answer_text", "answer_start"]]
        record_dicts = df_labels.to_dict("records")
        labels = [
            Label(
                question=rd["question"],
                answer=rd["answer_text"],
                is_correct_answer=True,
                is_correct_document=True,
                id=rd["id"],
                # "origin" and "document_id" are not columns of self.df, so
                # these .get() calls always fall back to the defaults; kept for
                # forward compatibility with enriched record dicts.
                origin=rd.get("origin", "SquadData tool"),
                document_id=rd.get("document_id", None)
            ) for rd in record_dicts
        ]
        return labels

    @staticmethod
    def to_df(data: List[Dict]) -> pd.DataFrame:
        """Convert a list of SQuAD document dictionaries into a pandas dataframe (each row is one annotation)

        :param data: List of SQuAD document dictionaries
        :return: DataFrame with columns as in COLUMN_NAMES
        """
        flat = []
        for document in data:
            title = document["title"]
            for paragraph in document["paragraphs"]:
                context = paragraph["context"]
                for question in paragraph["qas"]:
                    q = question["question"]
                    id_ = question["id"]  # renamed from `id` so the builtin is not shadowed
                    is_impossible = question["is_impossible"]
                    # For no_answer samples: emit one row with an empty answer
                    if len(question["answers"]) == 0:
                        flat.append({"title": title,
                                     "context": context,
                                     "question": q,
                                     "id": id_,
                                     "answer_text": "",
                                     "answer_start": None,
                                     "is_impossible": is_impossible})
                    # For span answer samples: emit one row per answer span
                    else:
                        for answer in question["answers"]:
                            flat.append({"title": title,
                                         "context": context,
                                         "question": q,
                                         "id": id_,
                                         "answer_text": answer["text"],
                                         "answer_start": answer["answer_start"],
                                         "is_impossible": is_impossible})
        df = pd.DataFrame.from_records(flat)
        return df

    def count(self, unit: str = "questions") -> int:
        """
        Count the samples in the data. Choose from unit = "paragraphs", "questions", "answers", "no_answers", "span_answers"

        :param unit: What to count; "answers" counts both no_answers and span_answers
        :return: The number of samples of the chosen unit
        """
        c = 0
        for document in self.data:
            for paragraph in document["paragraphs"]:
                if unit == "paragraphs":
                    c += 1
                for question in paragraph["qas"]:
                    if unit == "questions":
                        c += 1
                    # Count no_answers
                    if len(question["answers"]) == 0:
                        if unit in ["answers", "no_answers"]:
                            c += 1
                    # Count span answers
                    else:
                        for _ in question["answers"]:
                            if unit in ["answers", "span_answers"]:
                                c += 1
        return c

    def df_to_data(self, df: pd.DataFrame) -> List[Dict]:
        """Convert a dataframe into SQuAD format data (list of SQuAD document dictionaries)

        Performs three successive aggregations: answers per question,
        questions per paragraph, paragraphs per document.

        :param df: DataFrame with columns as in COLUMN_NAMES
        :return: List of SQuAD document dictionaries
        """
        logger.info("Converting data frame to squad format data")

        # Aggregate the answers of each question
        logger.info("Aggregating the answers of each question")
        df_grouped_answers = df.groupby(["title", "context", "question", "id", "is_impossible"])
        df_aggregated_answers = df[["title", "context", "question", "id", "is_impossible"]].drop_duplicates().reset_index()
        answers = df_grouped_answers.progress_apply(self._aggregate_answers).rename("answers")
        answers = pd.DataFrame(answers).reset_index()
        df_aggregated_answers = pd.merge(df_aggregated_answers, answers)

        # Aggregate the questions of each passage
        logger.info("Aggregating the questions of each paragraphs of each document")
        df_grouped_questions = df_aggregated_answers.groupby(["title", "context"])
        df_aggregated_questions = df[["title", "context"]].drop_duplicates().reset_index()
        questions = df_grouped_questions.progress_apply(self._aggregate_questions).rename("qas")
        questions = pd.DataFrame(questions).reset_index()
        df_aggregated_questions = pd.merge(df_aggregated_questions, questions)

        logger.info("Aggregating the paragraphs of each document")
        df_grouped_paragraphs = df_aggregated_questions.groupby(["title"])
        df_aggregated_paragraphs = df[["title"]].drop_duplicates().reset_index()
        paragraphs = df_grouped_paragraphs.progress_apply(self._aggregate_passages).rename("paragraphs")
        paragraphs = pd.DataFrame(paragraphs).reset_index()
        df_aggregated_paragraphs = pd.merge(df_aggregated_paragraphs, paragraphs)

        df_aggregated_paragraphs = df_aggregated_paragraphs[["title", "paragraphs"]]
        ret = df_aggregated_paragraphs.to_dict("records")

        return ret

    @staticmethod
    def _aggregate_passages(x):
        # Collapse the rows of one document into its list of paragraph dicts
        x = x[["context", "qas"]]
        ret = x.to_dict("records")
        return ret

    @staticmethod
    def _aggregate_questions(x):
        # Collapse the rows of one paragraph into its list of question dicts
        x = x[["question", "id", "answers", "is_impossible"]]
        ret = x.to_dict("records")
        return ret

    @staticmethod
    def _aggregate_answers(x):
        # Collapse the rows of one question into its list of answer dicts
        x = x[["answer_text", "answer_start"]]
        x = x.rename(columns={"answer_text": "text"})

        # Span answer: answer_start is numeric and casts cleanly to int
        try:
            x["answer_start"] = x["answer_start"].astype(int)
            ret = x.to_dict("records")

        # No answer: answer_start is None/NaN, so the int cast raises
        except ValueError:
            ret = []

        return ret

    def set_data(self, data: List[Dict]):
        """Replace the data stored in this object and rebuild the dataframe view

        :param data: List of SQuAD document dictionaries
        """
        self.data = data
        self.df = self.to_df(data)

    def sample_questions(self, n: int) -> List[Dict]:
        """
        Return a sample of n questions in SQuAD format (list of SQuAD document dictionaries)
        Note, that if the same question is asked on multiple different passages, this fn treats that
        as a single question

        :param n: Number of distinct question strings to sample
        """
        all_questions = self.get_all_questions()
        sampled_questions = random.sample(all_questions, n)
        df_sampled = self.df[self.df["question"].isin(sampled_questions)]
        return self.df_to_data(df_sampled)

    def get_all_paragraphs(self) -> List[str]:
        """Return all paragraph strings"""
        return self.df["context"].unique().tolist()

    def get_all_questions(self) -> List[str]:
        """Return all question strings. Note that if the same question appears for different paragraphs, it will be
        returned multiple times by this fn"""
        df_questions = self.df[["title", "context", "question"]]
        df_questions = df_questions.drop_duplicates()
        questions = df_questions["question"].tolist()
        return questions

    def get_all_document_titles(self) -> List[str]:
        """Return all document title strings"""
        return self.df["title"].unique().tolist()


if __name__ == "__main__":
    # Download the SQuAD dataset if it isn't at target directory
    read_squad_file("../data/squad20/train-v2.0.json")

    filename1 = "../data/squad20/train-v2.0.json"
    filename2 = "../data/squad20/dev-v2.0.json"

    # Load file1 and take a sample of 10000 questions
    sd = SquadData.from_file(filename1)
    sample1 = sd.sample_questions(n=10000)

    # Set sd to now contain the sample of 10000 questions
    sd.set_data(sample1)

    # Merge sd with file2 and take a sample of 100 questions
    sd.merge_from_file(filename2)
    sample2 = sd.sample_questions(n=100)
    sd.set_data(sample2)

    # Save this sample of 100
    sd.save("../data/squad20/sample.json")

    paragraphs = sd.get_all_paragraphs()
    questions = sd.get_all_questions()
    titles = sd.get_all_document_titles()

    documents = sd.to_document_objs()
    labels = sd.to_label_objs()

    n_qs = sd.count(unit="questions")
    n_as = sd.count(unit="no_answers")
    n_ps = sd.count(unit="paragraphs")

    print(n_qs)
    print(n_as)
    print(n_ps)