mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-02 18:59:28 +00:00
Squad tools (#1029)
* Add first commit * Add support for conversion to and from pandas df * Add logging * Add functionality * Satisfy mypy * Incorporate reviewer feedback
This commit is contained in:
parent
373fef8d1e
commit
5d31e633ce
284
haystack/squad_data.py
Normal file
284
haystack/squad_data.py
Normal file
@ -0,0 +1,284 @@
|
||||
import json
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
from typing import Dict, List, Union
|
||||
import random
|
||||
|
||||
from haystack.schema import Document, Label
|
||||
|
||||
from farm.data_handler.utils import read_squad_file
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
tqdm.pandas()
|
||||
|
||||
COLUMN_NAMES = ["title", "context", "question", "id", "answer_text", "answer_start", "is_impossible"]
|
||||
|
||||
class SquadData:
|
||||
"""This class is designed to manipulate data that is in SQuAD format"""
|
||||
def __init__(self, squad_data):
|
||||
"""
|
||||
:param squad_data: SQuAD format data, either as a dict with a `data` key, or just a list of SQuAD documents
|
||||
"""
|
||||
if type(squad_data) == dict:
|
||||
self.version = squad_data.get("version")
|
||||
self.data = squad_data["data"]
|
||||
elif type(squad_data) == list:
|
||||
self.version = None
|
||||
self.data = squad_data
|
||||
self.df = self.to_df(self.data)
|
||||
|
||||
def merge_from_file(self, filename: str):
|
||||
"""Merge the contents of a SQuAD format json file with the data stored in this object"""
|
||||
new_data = json.load(open(filename))["data"]
|
||||
self.merge(new_data)
|
||||
|
||||
def merge(self, new_data: List):
|
||||
"""
|
||||
Merge data in SQuAD format with the data stored in this object
|
||||
:param new_data: A list of SQuAD document data
|
||||
"""
|
||||
df_new = self.to_df(new_data)
|
||||
self.df = pd.concat([df_new, self.df])
|
||||
self.data = self.df_to_data(self.df)
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, filename: str):
|
||||
"""
|
||||
Create a SquadData object by providing the name of a SQuAD format json file
|
||||
"""
|
||||
data = json.load(open(filename))
|
||||
return cls(data)
|
||||
|
||||
def save(self, filename: str):
|
||||
"""Write the data stored in this object to a json file"""
|
||||
with open(filename, "w") as f:
|
||||
squad_data = {"version": self.version, "data": self.data}
|
||||
json.dump(squad_data, f, indent=2)
|
||||
|
||||
def to_dpr_dataset(self):
|
||||
raise NotImplementedError(
|
||||
"SquadData.to_dpr_dataset() not yet implemented. "
|
||||
"For now, have a look at the script at haystack/retriever/squad_to_dpr.py"
|
||||
)
|
||||
|
||||
def to_document_objs(self):
|
||||
"""Export all paragraphs stored in this object to haystack.Document objects"""
|
||||
df_docs = self.df[["title", "context"]]
|
||||
df_docs = df_docs.drop_duplicates()
|
||||
record_dicts = df_docs.to_dict("records")
|
||||
documents = [
|
||||
Document(
|
||||
text=rd["context"],
|
||||
id=rd["title"]
|
||||
) for rd in record_dicts
|
||||
]
|
||||
return documents
|
||||
|
||||
def to_label_objs(self):
|
||||
"""Export all labels stored in this object to haystack.Label objects"""
|
||||
df_labels = self.df[["id", "question", "answer_text", "answer_start"]]
|
||||
record_dicts = df_labels.to_dict("records")
|
||||
labels = [
|
||||
Label(
|
||||
question=rd["question"],
|
||||
answer=rd["answer_text"],
|
||||
is_correct_answer=True,
|
||||
is_correct_document=True,
|
||||
id=rd["id"],
|
||||
origin=rd.get("origin", "SquadData tool"),
|
||||
document_id=rd.get("document_id", None)
|
||||
) for rd in record_dicts
|
||||
]
|
||||
return labels
|
||||
|
||||
@staticmethod
|
||||
def to_df(data):
|
||||
"""Convert a list of SQuAD document dictionaries into a pandas dataframe (each row is one annotation)"""
|
||||
flat = []
|
||||
for document in data:
|
||||
title = document["title"]
|
||||
for paragraph in document["paragraphs"]:
|
||||
context = paragraph["context"]
|
||||
for question in paragraph["qas"]:
|
||||
q = question["question"]
|
||||
id = question["id"]
|
||||
is_impossible = question["is_impossible"]
|
||||
# For no_answer samples
|
||||
if len(question["answers"]) == 0:
|
||||
flat.append({"title": title,
|
||||
"context": context,
|
||||
"question": q,
|
||||
"id": id,
|
||||
"answer_text": "",
|
||||
"answer_start": None,
|
||||
"is_impossible": is_impossible})
|
||||
# For span answer samples
|
||||
else:
|
||||
for answer in question["answers"]:
|
||||
answer_text = answer["text"]
|
||||
answer_start = answer["answer_start"]
|
||||
flat.append({"title": title,
|
||||
"context": context,
|
||||
"question": q,
|
||||
"id": id,
|
||||
"answer_text": answer_text,
|
||||
"answer_start": answer_start,
|
||||
"is_impossible": is_impossible})
|
||||
df = pd.DataFrame.from_records(flat)
|
||||
return df
|
||||
|
||||
def count(self, unit="questions"):
|
||||
"""
|
||||
Count the samples in the data. Choose from unit = "paragraphs", "questions", "answers", "no_answers", "span_answers"
|
||||
"""
|
||||
c = 0
|
||||
for document in self.data:
|
||||
for paragraph in document["paragraphs"]:
|
||||
if unit == "paragraphs":
|
||||
c += 1
|
||||
for question in paragraph["qas"]:
|
||||
if unit == "questions":
|
||||
c += 1
|
||||
# Count no_answers
|
||||
if len(question["answers"]) == 0:
|
||||
if unit in ["answers", "no_answers"]:
|
||||
c += 1
|
||||
# Count span answers
|
||||
else:
|
||||
for answer in question["answers"]:
|
||||
if unit in ["answers", "span_answers"]:
|
||||
c += 1
|
||||
return c
|
||||
|
||||
def df_to_data(self, df):
|
||||
"""Convert a dataframe into SQuAD format data (list of SQuAD document dictionaries)"""
|
||||
|
||||
logger.info("Converting data frame to squad format data")
|
||||
|
||||
# Aggregate the answers of each question
|
||||
logger.info("Aggregating the answers of each question")
|
||||
df_grouped_answers = df.groupby(["title", "context", "question", "id", "is_impossible"])
|
||||
df_aggregated_answers = df[["title", "context", "question", "id", "is_impossible"]].drop_duplicates().reset_index()
|
||||
answers = df_grouped_answers.progress_apply(self._aggregate_answers).rename("answers")
|
||||
answers = pd.DataFrame(answers).reset_index()
|
||||
df_aggregated_answers = pd.merge(df_aggregated_answers, answers)
|
||||
|
||||
# Aggregate the questions of each passage
|
||||
logger.info("Aggregating the questions of each paragraphs of each document")
|
||||
df_grouped_questions = df_aggregated_answers.groupby(["title", "context"])
|
||||
df_aggregated_questions = df[["title", "context"]].drop_duplicates().reset_index()
|
||||
questions = df_grouped_questions.progress_apply(self._aggregate_questions).rename("qas")
|
||||
questions = pd.DataFrame(questions).reset_index()
|
||||
df_aggregated_questions = pd.merge(df_aggregated_questions, questions)
|
||||
|
||||
logger.info("Aggregating the paragraphs of each document")
|
||||
df_grouped_paragraphs = df_aggregated_questions.groupby(["title"])
|
||||
df_aggregated_paragraphs = df[["title"]].drop_duplicates().reset_index()
|
||||
paragraphs = df_grouped_paragraphs.progress_apply(self._aggregate_passages).rename("paragraphs")
|
||||
paragraphs = pd.DataFrame(paragraphs).reset_index()
|
||||
df_aggregated_paragraphs = pd.merge(df_aggregated_paragraphs, paragraphs)
|
||||
|
||||
df_aggregated_paragraphs = df_aggregated_paragraphs[["title", "paragraphs"]]
|
||||
ret = df_aggregated_paragraphs.to_dict("records")
|
||||
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _aggregate_passages(x):
|
||||
x = x[["context", "qas"]]
|
||||
ret = x.to_dict("records")
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _aggregate_questions(x):
|
||||
x = x[["question", "id", "answers", "is_impossible"]]
|
||||
ret = x.to_dict("records")
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _aggregate_answers(x):
|
||||
x = x[["answer_text", "answer_start"]]
|
||||
x = x.rename(columns={"answer_text": "text"})
|
||||
|
||||
# Span anwser
|
||||
try:
|
||||
x["answer_start"] = x["answer_start"].astype(int)
|
||||
ret = x.to_dict("records")
|
||||
|
||||
# No answer
|
||||
except ValueError:
|
||||
ret = []
|
||||
|
||||
return ret
|
||||
|
||||
def set_data(self, data):
|
||||
self.data = data
|
||||
self.df = self.to_df(data)
|
||||
|
||||
def sample_questions(self, n):
|
||||
"""
|
||||
Return a sample of n questions in SQuAD format (list of SQuAD document dictionaries)
|
||||
Note, that if the same question is asked on multiple different passages, this fn treats that
|
||||
as a single question
|
||||
"""
|
||||
all_questions = self.get_all_questions()
|
||||
sampled_questions = random.sample(all_questions, n)
|
||||
df_sampled = self.df[self.df["question"].isin(sampled_questions)]
|
||||
return self.df_to_data(df_sampled)
|
||||
|
||||
def get_all_paragraphs(self):
|
||||
"""Return all paragraph strings"""
|
||||
return self.df["context"].unique().tolist()
|
||||
|
||||
def get_all_questions(self):
|
||||
"""Return all question strings. Note that if the same question appears for different paragraphs, it will be
|
||||
returned multiple times by this fn"""
|
||||
df_questions = self.df[["title", "context", "question"]]
|
||||
df_questions = df_questions.drop_duplicates()
|
||||
questions = df_questions["question"].tolist()
|
||||
return questions
|
||||
|
||||
def get_all_document_titles(self):
|
||||
"""Return all document title strings"""
|
||||
return self.df["title"].unique().tolist()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Download the SQuAD dataset if it isn't at target directory
|
||||
read_squad_file( "../data/squad20/train-v2.0.json")
|
||||
|
||||
filename1 = "../data/squad20/train-v2.0.json"
|
||||
filename2 = "../data/squad20/dev-v2.0.json"
|
||||
|
||||
# Load file1 and take a sample of 10000 questions
|
||||
sd = SquadData.from_file(filename1)
|
||||
sample1 = sd.sample_questions(n=10000)
|
||||
|
||||
# Set sd to now contain the sample of 10000 questions
|
||||
sd.set_data(sample1)
|
||||
|
||||
# Merge sd with file2 and take a sample of 100 questions
|
||||
sd.merge_from_file(filename2)
|
||||
sample2 = sd.sample_questions(n=100)
|
||||
sd.set_data(sample2)
|
||||
|
||||
# Save this sample of 100
|
||||
sd.save("../data/squad20/sample.json")
|
||||
|
||||
paragraphs = sd.get_all_paragraphs()
|
||||
questions = sd.get_all_questions()
|
||||
titles = sd.get_all_document_titles()
|
||||
|
||||
documents = sd.to_document_objs()
|
||||
labels = sd.to_label_objs()
|
||||
|
||||
n_qs = sd.count(unit="questions")
|
||||
n_as = sd.count(unit="no_answers")
|
||||
n_ps = sd.count(unit="paragraphs")
|
||||
|
||||
print(n_qs)
|
||||
print(n_as)
|
||||
print(n_ps)
|
||||
Loading…
x
Reference in New Issue
Block a user