diff --git a/haystack/utils/squad_data.py b/haystack/utils/squad_data.py index de847fd83..f725a7c0a 100644 --- a/haystack/utils/squad_data.py +++ b/haystack/utils/squad_data.py @@ -171,7 +171,8 @@ class SquadData: c += 1 return c - def df_to_data(self, df): + @classmethod + def df_to_data(cls, df): """ Convert a dataframe into SQuAD format data (list of SQuAD document dictionaries). """ @@ -183,7 +184,7 @@ class SquadData: df_aggregated_answers = ( df[["title", "context", "question", "id", "is_impossible"]].drop_duplicates().reset_index() ) - answers = df_grouped_answers.progress_apply(self._aggregate_answers).rename("answers") + answers = df_grouped_answers.progress_apply(cls._aggregate_answers).rename("answers") answers = pd.DataFrame(answers).reset_index() df_aggregated_answers = pd.merge(df_aggregated_answers, answers) @@ -191,14 +192,14 @@ class SquadData: logger.info("Aggregating the questions of each paragraphs of each document") df_grouped_questions = df_aggregated_answers.groupby(["title", "context"]) df_aggregated_questions = df[["title", "context"]].drop_duplicates().reset_index() - questions = df_grouped_questions.progress_apply(self._aggregate_questions).rename("qas") + questions = df_grouped_questions.progress_apply(cls._aggregate_questions).rename("qas") questions = pd.DataFrame(questions).reset_index() df_aggregated_questions = pd.merge(df_aggregated_questions, questions) logger.info("Aggregating the paragraphs of each document") df_grouped_paragraphs = df_aggregated_questions.groupby(["title"]) df_aggregated_paragraphs = df[["title"]].drop_duplicates().reset_index() - paragraphs = df_grouped_paragraphs.progress_apply(self._aggregate_passages).rename("paragraphs") + paragraphs = df_grouped_paragraphs.progress_apply(cls._aggregate_passages).rename("paragraphs") paragraphs = pd.DataFrame(paragraphs).reset_index() df_aggregated_paragraphs = pd.merge(df_aggregated_paragraphs, paragraphs) diff --git a/test/test_utils.py b/test/test_utils.py index 
c3a5fe385..d6b10d6b7 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,4 +1,5 @@
 import pytest
+import pandas as pd
 from pathlib import Path
 
 from haystack.utils.preprocessing import convert_files_to_dicts, tika_convert_files_to_dicts
@@ -38,3 +39,33 @@ def test_squad_augmentation():
     original_squad = SquadData.from_file(input_)
     augmented_squad = SquadData.from_file(output)
     assert original_squad.count(unit="paragraph") == augmented_squad.count(unit="paragraph") * multiplication_factor
+
+
+def test_squad_to_df():
+    df = pd.DataFrame(
+        [["title", "context", "question", "id", "answer", 1, False]],
+        columns=["title", "context", "question", "id", "answer_text", "answer_start", "is_impossible"],
+    )
+
+    expected_result = [
+        {
+            "title": "title",
+            "paragraphs": [
+                {
+                    "context": "context",
+                    "qas": [
+                        {
+                            "question": "question",
+                            "id": "id",
+                            "answers": [{"text": "answer", "answer_start": 1}],
+                            "is_impossible": False,
+                        }
+                    ],
+                }
+            ],
+        }
+    ]
+
+    result = SquadData.df_to_data(df)
+
+    assert result == expected_result