mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-13 08:33:57 +00:00
Squad tools (#1029)
* Add first commit * Add support for conversion to and from pandas df * Add logging * Add functionality * Satisfy mypy * Incorporate reviewer feedback
This commit is contained in:
parent
373fef8d1e
commit
5d31e633ce
284
haystack/squad_data.py
Normal file
284
haystack/squad_data.py
Normal file
@ -0,0 +1,284 @@
|
|||||||
|
import json
import pandas as pd
from tqdm import tqdm
import logging
from typing import Dict, List, Union
import random

from haystack.schema import Document, Label

from farm.data_handler.utils import read_squad_file

# NOTE(review): configuring the root logger and forcing DEBUG at import time is
# unusual for a library module (it affects every importer) — confirm intentional.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Registers tqdm's `progress_apply` on pandas objects (used for progress bars
# during the groupby aggregations in SquadData.df_to_data).
tqdm.pandas()

# Canonical column set of the flat one-row-per-annotation DataFrame built by SquadData.to_df.
COLUMN_NAMES = ["title", "context", "question", "id", "answer_text", "answer_start", "is_impossible"]
|
||||||
|
|
||||||
|
class SquadData:
|
||||||
|
"""This class is designed to manipulate data that is in SQuAD format"""
|
||||||
|
def __init__(self, squad_data):
|
||||||
|
"""
|
||||||
|
:param squad_data: SQuAD format data, either as a dict with a `data` key, or just a list of SQuAD documents
|
||||||
|
"""
|
||||||
|
if type(squad_data) == dict:
|
||||||
|
self.version = squad_data.get("version")
|
||||||
|
self.data = squad_data["data"]
|
||||||
|
elif type(squad_data) == list:
|
||||||
|
self.version = None
|
||||||
|
self.data = squad_data
|
||||||
|
self.df = self.to_df(self.data)
|
||||||
|
|
||||||
|
def merge_from_file(self, filename: str):
|
||||||
|
"""Merge the contents of a SQuAD format json file with the data stored in this object"""
|
||||||
|
new_data = json.load(open(filename))["data"]
|
||||||
|
self.merge(new_data)
|
||||||
|
|
||||||
|
def merge(self, new_data: List):
|
||||||
|
"""
|
||||||
|
Merge data in SQuAD format with the data stored in this object
|
||||||
|
:param new_data: A list of SQuAD document data
|
||||||
|
"""
|
||||||
|
df_new = self.to_df(new_data)
|
||||||
|
self.df = pd.concat([df_new, self.df])
|
||||||
|
self.data = self.df_to_data(self.df)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_file(cls, filename: str):
|
||||||
|
"""
|
||||||
|
Create a SquadData object by providing the name of a SQuAD format json file
|
||||||
|
"""
|
||||||
|
data = json.load(open(filename))
|
||||||
|
return cls(data)
|
||||||
|
|
||||||
|
def save(self, filename: str):
|
||||||
|
"""Write the data stored in this object to a json file"""
|
||||||
|
with open(filename, "w") as f:
|
||||||
|
squad_data = {"version": self.version, "data": self.data}
|
||||||
|
json.dump(squad_data, f, indent=2)
|
||||||
|
|
||||||
|
def to_dpr_dataset(self):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"SquadData.to_dpr_dataset() not yet implemented. "
|
||||||
|
"For now, have a look at the script at haystack/retriever/squad_to_dpr.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_document_objs(self):
|
||||||
|
"""Export all paragraphs stored in this object to haystack.Document objects"""
|
||||||
|
df_docs = self.df[["title", "context"]]
|
||||||
|
df_docs = df_docs.drop_duplicates()
|
||||||
|
record_dicts = df_docs.to_dict("records")
|
||||||
|
documents = [
|
||||||
|
Document(
|
||||||
|
text=rd["context"],
|
||||||
|
id=rd["title"]
|
||||||
|
) for rd in record_dicts
|
||||||
|
]
|
||||||
|
return documents
|
||||||
|
|
||||||
|
def to_label_objs(self):
|
||||||
|
"""Export all labels stored in this object to haystack.Label objects"""
|
||||||
|
df_labels = self.df[["id", "question", "answer_text", "answer_start"]]
|
||||||
|
record_dicts = df_labels.to_dict("records")
|
||||||
|
labels = [
|
||||||
|
Label(
|
||||||
|
question=rd["question"],
|
||||||
|
answer=rd["answer_text"],
|
||||||
|
is_correct_answer=True,
|
||||||
|
is_correct_document=True,
|
||||||
|
id=rd["id"],
|
||||||
|
origin=rd.get("origin", "SquadData tool"),
|
||||||
|
document_id=rd.get("document_id", None)
|
||||||
|
) for rd in record_dicts
|
||||||
|
]
|
||||||
|
return labels
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_df(data):
|
||||||
|
"""Convert a list of SQuAD document dictionaries into a pandas dataframe (each row is one annotation)"""
|
||||||
|
flat = []
|
||||||
|
for document in data:
|
||||||
|
title = document["title"]
|
||||||
|
for paragraph in document["paragraphs"]:
|
||||||
|
context = paragraph["context"]
|
||||||
|
for question in paragraph["qas"]:
|
||||||
|
q = question["question"]
|
||||||
|
id = question["id"]
|
||||||
|
is_impossible = question["is_impossible"]
|
||||||
|
# For no_answer samples
|
||||||
|
if len(question["answers"]) == 0:
|
||||||
|
flat.append({"title": title,
|
||||||
|
"context": context,
|
||||||
|
"question": q,
|
||||||
|
"id": id,
|
||||||
|
"answer_text": "",
|
||||||
|
"answer_start": None,
|
||||||
|
"is_impossible": is_impossible})
|
||||||
|
# For span answer samples
|
||||||
|
else:
|
||||||
|
for answer in question["answers"]:
|
||||||
|
answer_text = answer["text"]
|
||||||
|
answer_start = answer["answer_start"]
|
||||||
|
flat.append({"title": title,
|
||||||
|
"context": context,
|
||||||
|
"question": q,
|
||||||
|
"id": id,
|
||||||
|
"answer_text": answer_text,
|
||||||
|
"answer_start": answer_start,
|
||||||
|
"is_impossible": is_impossible})
|
||||||
|
df = pd.DataFrame.from_records(flat)
|
||||||
|
return df
|
||||||
|
|
||||||
|
def count(self, unit="questions"):
|
||||||
|
"""
|
||||||
|
Count the samples in the data. Choose from unit = "paragraphs", "questions", "answers", "no_answers", "span_answers"
|
||||||
|
"""
|
||||||
|
c = 0
|
||||||
|
for document in self.data:
|
||||||
|
for paragraph in document["paragraphs"]:
|
||||||
|
if unit == "paragraphs":
|
||||||
|
c += 1
|
||||||
|
for question in paragraph["qas"]:
|
||||||
|
if unit == "questions":
|
||||||
|
c += 1
|
||||||
|
# Count no_answers
|
||||||
|
if len(question["answers"]) == 0:
|
||||||
|
if unit in ["answers", "no_answers"]:
|
||||||
|
c += 1
|
||||||
|
# Count span answers
|
||||||
|
else:
|
||||||
|
for answer in question["answers"]:
|
||||||
|
if unit in ["answers", "span_answers"]:
|
||||||
|
c += 1
|
||||||
|
return c
|
||||||
|
|
||||||
|
def df_to_data(self, df):
|
||||||
|
"""Convert a dataframe into SQuAD format data (list of SQuAD document dictionaries)"""
|
||||||
|
|
||||||
|
logger.info("Converting data frame to squad format data")
|
||||||
|
|
||||||
|
# Aggregate the answers of each question
|
||||||
|
logger.info("Aggregating the answers of each question")
|
||||||
|
df_grouped_answers = df.groupby(["title", "context", "question", "id", "is_impossible"])
|
||||||
|
df_aggregated_answers = df[["title", "context", "question", "id", "is_impossible"]].drop_duplicates().reset_index()
|
||||||
|
answers = df_grouped_answers.progress_apply(self._aggregate_answers).rename("answers")
|
||||||
|
answers = pd.DataFrame(answers).reset_index()
|
||||||
|
df_aggregated_answers = pd.merge(df_aggregated_answers, answers)
|
||||||
|
|
||||||
|
# Aggregate the questions of each passage
|
||||||
|
logger.info("Aggregating the questions of each paragraphs of each document")
|
||||||
|
df_grouped_questions = df_aggregated_answers.groupby(["title", "context"])
|
||||||
|
df_aggregated_questions = df[["title", "context"]].drop_duplicates().reset_index()
|
||||||
|
questions = df_grouped_questions.progress_apply(self._aggregate_questions).rename("qas")
|
||||||
|
questions = pd.DataFrame(questions).reset_index()
|
||||||
|
df_aggregated_questions = pd.merge(df_aggregated_questions, questions)
|
||||||
|
|
||||||
|
logger.info("Aggregating the paragraphs of each document")
|
||||||
|
df_grouped_paragraphs = df_aggregated_questions.groupby(["title"])
|
||||||
|
df_aggregated_paragraphs = df[["title"]].drop_duplicates().reset_index()
|
||||||
|
paragraphs = df_grouped_paragraphs.progress_apply(self._aggregate_passages).rename("paragraphs")
|
||||||
|
paragraphs = pd.DataFrame(paragraphs).reset_index()
|
||||||
|
df_aggregated_paragraphs = pd.merge(df_aggregated_paragraphs, paragraphs)
|
||||||
|
|
||||||
|
df_aggregated_paragraphs = df_aggregated_paragraphs[["title", "paragraphs"]]
|
||||||
|
ret = df_aggregated_paragraphs.to_dict("records")
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _aggregate_passages(x):
|
||||||
|
x = x[["context", "qas"]]
|
||||||
|
ret = x.to_dict("records")
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _aggregate_questions(x):
|
||||||
|
x = x[["question", "id", "answers", "is_impossible"]]
|
||||||
|
ret = x.to_dict("records")
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _aggregate_answers(x):
|
||||||
|
x = x[["answer_text", "answer_start"]]
|
||||||
|
x = x.rename(columns={"answer_text": "text"})
|
||||||
|
|
||||||
|
# Span anwser
|
||||||
|
try:
|
||||||
|
x["answer_start"] = x["answer_start"].astype(int)
|
||||||
|
ret = x.to_dict("records")
|
||||||
|
|
||||||
|
# No answer
|
||||||
|
except ValueError:
|
||||||
|
ret = []
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def set_data(self, data):
|
||||||
|
self.data = data
|
||||||
|
self.df = self.to_df(data)
|
||||||
|
|
||||||
|
def sample_questions(self, n):
|
||||||
|
"""
|
||||||
|
Return a sample of n questions in SQuAD format (list of SQuAD document dictionaries)
|
||||||
|
Note, that if the same question is asked on multiple different passages, this fn treats that
|
||||||
|
as a single question
|
||||||
|
"""
|
||||||
|
all_questions = self.get_all_questions()
|
||||||
|
sampled_questions = random.sample(all_questions, n)
|
||||||
|
df_sampled = self.df[self.df["question"].isin(sampled_questions)]
|
||||||
|
return self.df_to_data(df_sampled)
|
||||||
|
|
||||||
|
def get_all_paragraphs(self):
|
||||||
|
"""Return all paragraph strings"""
|
||||||
|
return self.df["context"].unique().tolist()
|
||||||
|
|
||||||
|
def get_all_questions(self):
|
||||||
|
"""Return all question strings. Note that if the same question appears for different paragraphs, it will be
|
||||||
|
returned multiple times by this fn"""
|
||||||
|
df_questions = self.df[["title", "context", "question"]]
|
||||||
|
df_questions = df_questions.drop_duplicates()
|
||||||
|
questions = df_questions["question"].tolist()
|
||||||
|
return questions
|
||||||
|
|
||||||
|
def get_all_document_titles(self):
|
||||||
|
"""Return all document title strings"""
|
||||||
|
return self.df["title"].unique().tolist()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Demo / smoke-test script: requires the SQuAD 2.0 files under ../data/squad20.
    filename1 = "../data/squad20/train-v2.0.json"
    filename2 = "../data/squad20/dev-v2.0.json"

    # Download the SQuAD dataset if it isn't at the target directory
    # (define the path once and reuse it instead of repeating the literal).
    read_squad_file(filename1)

    # Load file1 and take a sample of 10000 questions
    sd = SquadData.from_file(filename1)
    sample1 = sd.sample_questions(n=10000)

    # Set sd to now contain the sample of 10000 questions
    sd.set_data(sample1)

    # Merge sd with file2 and take a sample of 100 questions
    sd.merge_from_file(filename2)
    sample2 = sd.sample_questions(n=100)
    sd.set_data(sample2)

    # Save this sample of 100
    sd.save("../data/squad20/sample.json")

    paragraphs = sd.get_all_paragraphs()
    questions = sd.get_all_questions()
    titles = sd.get_all_document_titles()

    documents = sd.to_document_objs()
    labels = sd.to_label_objs()

    n_qs = sd.count(unit="questions")
    n_as = sd.count(unit="no_answers")
    n_ps = sd.count(unit="paragraphs")

    print(n_qs)
    print(n_as)
    print(n_ps)
|
||||||
Loading…
x
Reference in New Issue
Block a user