Simplify SQuAD data to df conversion (#2124)

* Converting a DataFrame back to SQuAD data does not need an initialized SquadData object

* Apply Black

* Fix test case

* Apply Black

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
mathislucka authored 2022-02-04 12:37:56 +01:00, committed by GitHub
parent 53decdcefb
commit 34f9308e1a
2 changed files with 36 additions and 4 deletions
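
What the change enables, as a minimal usage sketch (the import path haystack.utils.squad_data and the column layout are assumptions taken from the repository layout and the new test below, not part of this commit):

    import pandas as pd
    from haystack.utils.squad_data import SquadData

    # One-row DataFrame in the same column layout as the new test below.
    df = pd.DataFrame(
        [["title", "context", "question", "id", "answer", 1, False]],
        columns=["title", "context", "question", "id", "answer_text", "answer_start", "is_impossible"],
    )

    # Previously df_to_data was an instance method (def df_to_data(self, df)), so a
    # SquadData object had to be initialized with existing SQuAD data just to convert.
    # With the classmethod it can be called straight on the class:
    squad_documents = SquadData.df_to_data(df)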


@@ -171,7 +171,8 @@ class SquadData:
                 c += 1
         return c
 
-    def df_to_data(self, df):
+    @classmethod
+    def df_to_data(cls, df):
         """
         Convert a dataframe into SQuAD format data (list of SQuAD document dictionaries).
         """
@@ -183,7 +184,7 @@ class SquadData:
         df_aggregated_answers = (
             df[["title", "context", "question", "id", "is_impossible"]].drop_duplicates().reset_index()
         )
-        answers = df_grouped_answers.progress_apply(self._aggregate_answers).rename("answers")
+        answers = df_grouped_answers.progress_apply(cls._aggregate_answers).rename("answers")
         answers = pd.DataFrame(answers).reset_index()
         df_aggregated_answers = pd.merge(df_aggregated_answers, answers)
 
@@ -191,14 +192,14 @@ class SquadData:
         logger.info("Aggregating the questions of each paragraphs of each document")
         df_grouped_questions = df_aggregated_answers.groupby(["title", "context"])
         df_aggregated_questions = df[["title", "context"]].drop_duplicates().reset_index()
-        questions = df_grouped_questions.progress_apply(self._aggregate_questions).rename("qas")
+        questions = df_grouped_questions.progress_apply(cls._aggregate_questions).rename("qas")
        questions = pd.DataFrame(questions).reset_index()
         df_aggregated_questions = pd.merge(df_aggregated_questions, questions)
 
         logger.info("Aggregating the paragraphs of each document")
         df_grouped_paragraphs = df_aggregated_questions.groupby(["title"])
         df_aggregated_paragraphs = df[["title"]].drop_duplicates().reset_index()
-        paragraphs = df_grouped_paragraphs.progress_apply(self._aggregate_passages).rename("paragraphs")
+        paragraphs = df_grouped_paragraphs.progress_apply(cls._aggregate_passages).rename("paragraphs")
         paragraphs = pd.DataFrame(paragraphs).reset_index()
         df_aggregated_paragraphs = pd.merge(df_aggregated_paragraphs, paragraphs)
 
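
Design note: resolving the aggregation helpers through cls rather than self means the DataFrame-to-SQuAD conversion no longer touches instance state, which is what allows df_to_data to run without an initialized SquadData object; this presumes _aggregate_answers, _aggregate_questions and _aggregate_passages are themselves static or class methods, as the cls lookup implies.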


@@ -1,4 +1,5 @@
 import pytest
+import pandas as pd
 from pathlib import Path
 from haystack.utils.preprocessing import convert_files_to_dicts, tika_convert_files_to_dicts
@@ -38,3 +39,33 @@ def test_squad_augmentation():
     original_squad = SquadData.from_file(input_)
     augmented_squad = SquadData.from_file(output)
     assert original_squad.count(unit="paragraph") == augmented_squad.count(unit="paragraph") * multiplication_factor
+
+
+def test_squad_to_df():
+    df = pd.DataFrame(
+        [["title", "context", "question", "id", "answer", 1, False]],
+        columns=["title", "context", "question", "id", "answer_text", "answer_start", "is_impossible"],
+    )
+
+    expected_result = [
+        {
+            "title": "title",
+            "paragraphs": [
+                {
+                    "context": "context",
+                    "qas": [
+                        {
+                            "question": "question",
+                            "id": "id",
+                            "answers": [{"text": "answer", "answer_start": 1}],
+                            "is_impossible": False,
+                        }
+                    ],
+                }
+            ],
+        }
+    ]
+
+    result = SquadData.df_to_data(df)
+
+    assert result == expected_result