diff --git a/haystack/utils/squad_data.py b/haystack/utils/squad_data.py index a4239b713..7f5931d5f 100644 --- a/haystack/utils/squad_data.py +++ b/haystack/utils/squad_data.py @@ -7,7 +7,7 @@ import pandas as pd from tqdm import tqdm import mmh3 -from haystack.schema import Document, Label +from haystack.schema import Document, Label, Answer from haystack.modeling.data_handler.processor import _read_squad_file @@ -84,24 +84,21 @@ class SquadData: documents = [Document(content=rd["context"], id=rd["title"]) for rd in record_dicts] return documents - # FIXME currently broken! Refactor to new Label objects - def to_label_objs(self): - """ - Export all labels stored in this object to haystack.Label objects. - """ - df_labels = self.df[["id", "question", "answer_text", "answer_start"]] + def to_label_objs(self, answer_type="generative"): + """Export all labels stored in this object to haystack.Label objects""" + df_labels = self.df[["id", "question", "answer_text", "answer_start", "context", "document_id"]] record_dicts = df_labels.to_dict("records") labels = [ - Label( # pylint: disable=no-value-for-parameter - query=rd["question"], - answer=rd["answer_text"], + Label( + query=record["question"], + answer=Answer(answer=record["answer_text"], answer_type=answer_type), is_correct_answer=True, is_correct_document=True, - id=rd["id"], - origin=rd.get("origin", "SquadData tool"), - document_id=rd.get("document_id", None), + id=record["id"], + origin=record.get("origin", "gold-label"), + document=Document(content=record.get("context"), id=str(record["document_id"])), ) - for rd in record_dicts + for record in record_dicts ] return labels @@ -117,7 +114,7 @@ class SquadData: for question in paragraph["qas"]: q = question["question"] id = question["id"] - is_impossible = question["is_impossible"] + is_impossible = question.get("is_impossible", False) # For no_answer samples if len(question["answers"]) == 0: flat.append( diff --git a/test/others/test_squad_data.py b/test/others/test_squad_data.py new file mode 100644 index 000000000..cce21c77b --- /dev/null +++ b/test/others/test_squad_data.py @@ -0,0 +1,109 @@ +import pandas as pd +from haystack.utils.squad_data import SquadData +from haystack.utils.augment_squad import augment_squad +from ..conftest import SAMPLES_PATH +from haystack.schema import Document, Label, Answer + + +def test_squad_augmentation(): + input_ = SAMPLES_PATH / "squad" / "tiny.json" + output = SAMPLES_PATH / "squad" / "tiny_augmented.json" + glove_path = SAMPLES_PATH / "glove" / "tiny.txt" # dummy glove file, will not even be use when augmenting tiny.json + multiplication_factor = 5 + augment_squad( + model="distilbert-base-uncased", + tokenizer="distilbert-base-uncased", + squad_path=input_, + output_path=output, + glove_path=glove_path, + multiplication_factor=multiplication_factor, + ) + original_squad = SquadData.from_file(input_) + augmented_squad = SquadData.from_file(output) + assert original_squad.count(unit="paragraph") == augmented_squad.count(unit="paragraph") * multiplication_factor + + +def test_squad_to_df(): + df = pd.DataFrame( + [["title", "context", "question", "id", "answer", 1, False]], + columns=["title", "context", "question", "id", "answer_text", "answer_start", "is_impossible"], + ) + expected_result = [ + { + "title": "title", + "paragraphs": [ + { + "context": "context", + "qas": [ + { + "question": "question", + "id": "id", + "answers": [{"text": "answer", "answer_start": 1}], + "is_impossible": False, + } + ], + } + ], + } + ] + + result = SquadData.df_to_data(df) + + assert result == expected_result + + +def test_to_label_object(): + squad_data_list = [ + { + "title": "title", + "paragraphs": [ + { + "context": "context", + "qas": [ + { + "question": "question", + "id": "id", + "answers": [{"text": "answer", "answer_start": 1}], + "is_impossible": False, + }, + { + "question": "another question", + "id": "another_id", + "answers": [{"text": "this is the response", "answer_start": 1}], + "is_impossible": False, + }, + ], + }, + { + "context": "the second paragraph context", + "qas": [ + { + "question": "the third question", + "id": "id_3", + "answers": [{"text": "this is another response", "answer_start": 1}], + "is_impossible": False, + }, + { + "question": "the forth question", + "id": "id_4", + "answers": [{"text": "this is the response", "answer_start": 1}], + "is_impossible": False, + }, + ], + }, + ], + } + ] + squad_data = SquadData(squad_data=squad_data_list) + answer_type = "generative" + labels = squad_data.to_label_objs(answer_type=answer_type) + for label, expected_question in zip(labels, squad_data.df.iterrows()): + expected_question = expected_question[1] + assert isinstance(label, Label) + assert isinstance(label.document, Document) + assert isinstance(label.answer, Answer) + assert label.query == expected_question["question"] + assert label.document.content == expected_question.context + assert label.document.id == expected_question.document_id + assert label.id == expected_question.id + assert label.answer.answer == expected_question.answer_text diff --git a/test/others/test_utils.py b/test/others/test_utils.py index f1d523214..fb205041a 100644 --- a/test/others/test_utils.py +++ b/test/others/test_utils.py @@ -4,18 +4,14 @@ from random import random import numpy as np import pytest import pandas as pd - import responses from responses import matchers from haystack.errors import OpenAIRateLimitError from haystack.utils.deepsetcloud import DeepsetCloud, DeepsetCloudExperiments - from haystack.utils.preprocessing import convert_files_to_docs, tika_convert_files_to_docs from haystack.utils.cleaning import clean_wiki_text -from haystack.utils.augment_squad import augment_squad from haystack.utils.reflection import retry_with_exponential_backoff -from haystack.utils.squad_data import SquadData from haystack.utils.context_matching import calculate_context_similarity, match_context, match_contexts from ..conftest import DC_API_ENDPOINT, DC_API_KEY, MOCK_DC, SAMPLES_PATH, deepset_cloud_fixture @@ -52,54 +48,6 @@ def test_tika_convert_files_to_docs(): assert documents and len(documents) > 0 -def test_squad_augmentation(): - input_ = SAMPLES_PATH / "squad" / "tiny.json" - output = SAMPLES_PATH / "squad" / "tiny_augmented.json" - glove_path = SAMPLES_PATH / "glove" / "tiny.txt" # dummy glove file, will not even be use when augmenting tiny.json - multiplication_factor = 5 - augment_squad( - model="distilbert-base-uncased", - tokenizer="distilbert-base-uncased", - squad_path=input_, - output_path=output, - glove_path=glove_path, - multiplication_factor=multiplication_factor, - ) - original_squad = SquadData.from_file(input_) - augmented_squad = SquadData.from_file(output) - assert original_squad.count(unit="paragraph") == augmented_squad.count(unit="paragraph") * multiplication_factor - - -def test_squad_to_df(): - df = pd.DataFrame( - [["title", "context", "question", "id", "answer", 1, False]], - columns=["title", "context", "question", "id", "answer_text", "answer_start", "is_impossible"], - ) - - expected_result = [ - { - "title": "title", - "paragraphs": [ - { - "context": "context", - "qas": [ - { - "question": "question", - "id": "id", - "answers": [{"text": "answer", "answer_start": 1}], - "is_impossible": False, - } - ], - } - ], - } - ] - - result = SquadData.df_to_data(df) - - assert result == expected_result - - def test_calculate_context_similarity_on_parts_of_whole_document(): whole_document = TEST_CONTEXT min_length = 100