mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-24 08:28:22 +00:00
Simplify SQuAD data to df conversion (#2124)
* Conversion to df does not need initialization
* Apply Black
* fix test case
* Apply Black

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
53decdcefb
commit
34f9308e1a
@ -171,7 +171,8 @@ class SquadData:
|
|||||||
c += 1
|
c += 1
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def df_to_data(self, df):
|
@classmethod
|
||||||
|
def df_to_data(cls, df):
|
||||||
"""
|
"""
|
||||||
Convert a dataframe into SQuAD format data (list of SQuAD document dictionaries).
|
Convert a dataframe into SQuAD format data (list of SQuAD document dictionaries).
|
||||||
"""
|
"""
|
||||||
@ -183,7 +184,7 @@ class SquadData:
|
|||||||
df_aggregated_answers = (
|
df_aggregated_answers = (
|
||||||
df[["title", "context", "question", "id", "is_impossible"]].drop_duplicates().reset_index()
|
df[["title", "context", "question", "id", "is_impossible"]].drop_duplicates().reset_index()
|
||||||
)
|
)
|
||||||
answers = df_grouped_answers.progress_apply(self._aggregate_answers).rename("answers")
|
answers = df_grouped_answers.progress_apply(cls._aggregate_answers).rename("answers")
|
||||||
answers = pd.DataFrame(answers).reset_index()
|
answers = pd.DataFrame(answers).reset_index()
|
||||||
df_aggregated_answers = pd.merge(df_aggregated_answers, answers)
|
df_aggregated_answers = pd.merge(df_aggregated_answers, answers)
|
||||||
|
|
||||||
@ -191,14 +192,14 @@ class SquadData:
|
|||||||
logger.info("Aggregating the questions of each paragraphs of each document")
|
logger.info("Aggregating the questions of each paragraphs of each document")
|
||||||
df_grouped_questions = df_aggregated_answers.groupby(["title", "context"])
|
df_grouped_questions = df_aggregated_answers.groupby(["title", "context"])
|
||||||
df_aggregated_questions = df[["title", "context"]].drop_duplicates().reset_index()
|
df_aggregated_questions = df[["title", "context"]].drop_duplicates().reset_index()
|
||||||
questions = df_grouped_questions.progress_apply(self._aggregate_questions).rename("qas")
|
questions = df_grouped_questions.progress_apply(cls._aggregate_questions).rename("qas")
|
||||||
questions = pd.DataFrame(questions).reset_index()
|
questions = pd.DataFrame(questions).reset_index()
|
||||||
df_aggregated_questions = pd.merge(df_aggregated_questions, questions)
|
df_aggregated_questions = pd.merge(df_aggregated_questions, questions)
|
||||||
|
|
||||||
logger.info("Aggregating the paragraphs of each document")
|
logger.info("Aggregating the paragraphs of each document")
|
||||||
df_grouped_paragraphs = df_aggregated_questions.groupby(["title"])
|
df_grouped_paragraphs = df_aggregated_questions.groupby(["title"])
|
||||||
df_aggregated_paragraphs = df[["title"]].drop_duplicates().reset_index()
|
df_aggregated_paragraphs = df[["title"]].drop_duplicates().reset_index()
|
||||||
paragraphs = df_grouped_paragraphs.progress_apply(self._aggregate_passages).rename("paragraphs")
|
paragraphs = df_grouped_paragraphs.progress_apply(cls._aggregate_passages).rename("paragraphs")
|
||||||
paragraphs = pd.DataFrame(paragraphs).reset_index()
|
paragraphs = pd.DataFrame(paragraphs).reset_index()
|
||||||
df_aggregated_paragraphs = pd.merge(df_aggregated_paragraphs, paragraphs)
|
df_aggregated_paragraphs = pd.merge(df_aggregated_paragraphs, paragraphs)
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
import pandas as pd
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from haystack.utils.preprocessing import convert_files_to_dicts, tika_convert_files_to_dicts
|
from haystack.utils.preprocessing import convert_files_to_dicts, tika_convert_files_to_dicts
|
||||||
@ -38,3 +39,33 @@ def test_squad_augmentation():
|
|||||||
original_squad = SquadData.from_file(input_)
|
original_squad = SquadData.from_file(input_)
|
||||||
augmented_squad = SquadData.from_file(output)
|
augmented_squad = SquadData.from_file(output)
|
||||||
assert original_squad.count(unit="paragraph") == augmented_squad.count(unit="paragraph") * multiplication_factor
|
assert original_squad.count(unit="paragraph") == augmented_squad.count(unit="paragraph") * multiplication_factor
|
||||||
|
|
||||||
|
|
||||||
|
def test_squad_to_df():
    """Check that ``SquadData.df_to_data`` turns a one-row dataframe into nested SQuAD data.

    The call is made on the class itself, without creating a ``SquadData`` instance.
    """
    column_names = ["title", "context", "question", "id", "answer_text", "answer_start", "is_impossible"]
    single_row = ["title", "context", "question", "id", "answer", 1, False]
    frame = pd.DataFrame([single_row], columns=column_names)

    # Build the expected structure from the innermost level outwards:
    # answer -> question -> paragraph -> document.
    qa_entry = {
        "question": "question",
        "id": "id",
        "answers": [{"text": "answer", "answer_start": 1}],
        "is_impossible": False,
    }
    paragraph_entry = {"context": "context", "qas": [qa_entry]}
    document_entry = {"title": "title", "paragraphs": [paragraph_entry]}

    assert SquadData.df_to_data(frame) == [document_entry]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user