haystack/test/test_utils.py
MichelBartels e8cd5ea943
Add distillation to finetuning tutorial (#2025)
* Add finetuning tutorial

* Add latest docstring and tutorial changes

* fix typo

* Add latest docstring and tutorial changes

* improve distillation explanation in finetuning tutorial

* Add latest docstring and tutorial changes

* allow augment_squad.py to be easier to call from within python

* Update Tutorial2_Finetune_a_model_on_your_data.py

* fix squad augmentation test

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2022-01-20 12:18:32 +01:00

29 lines
1.4 KiB
Python

import pytest
from pathlib import Path
from haystack.utils.preprocessing import convert_files_to_dicts, tika_convert_files_to_dicts
from haystack.utils.cleaning import clean_wiki_text
from haystack.utils.augment_squad import augment_squad
from haystack.utils.squad_data import SquadData
def test_convert_files_to_dicts():
    """Converting the files in the ``samples`` fixture directory yields at least one document."""
    # Resolve the fixture directory relative to this file rather than the current
    # working directory, so the test passes no matter where pytest is launched from.
    samples_dir = Path(__file__).parent / "samples"
    documents = convert_files_to_dicts(dir_path=str(samples_dir), clean_func=clean_wiki_text, split_paragraphs=True)
    # An empty (or None) result is falsy, so this single check covers "not empty".
    assert documents
@pytest.mark.tika
def test_tika_convert_files_to_dicts():
    """Converting the ``samples`` fixture directory via Tika yields at least one document."""
    # Resolve the fixture directory relative to this file rather than the current
    # working directory, so the test passes no matter where pytest is launched from.
    samples_dir = Path(__file__).parent / "samples"
    documents = tika_convert_files_to_dicts(
        dir_path=str(samples_dir), clean_func=clean_wiki_text, split_paragraphs=True
    )
    # An empty (or None) result is falsy, so this single check covers "not empty".
    assert documents
def test_squad_augmentation():
    """Augmenting a tiny SQuAD file should multiply its paragraph count by ``multiplication_factor``.

    The augmented file is written to a temporary directory, so running the test
    does not leave artifacts in the checked-in ``samples`` fixtures.
    """
    import tempfile  # local import: only this test needs a scratch directory

    # Resolve fixtures relative to this file, independent of the working directory.
    samples_dir = Path(__file__).parent / "samples"
    input_ = samples_dir / "squad" / "tiny.json"
    glove_path = samples_dir / "glove" / "tiny.txt"  # dummy glove file, will not even be used when augmenting tiny.json
    multiplication_factor = 5

    with tempfile.TemporaryDirectory() as tmp_dir:
        output = Path(tmp_dir) / "tiny_augmented.json"
        augment_squad(
            model="distilbert-base-uncased",
            tokenizer="distilbert-base-uncased",
            squad_path=input_,
            output_path=output,
            glove_path=glove_path,
            multiplication_factor=multiplication_factor,
        )
        # Load and count while the temporary output file still exists.
        original_squad = SquadData.from_file(input_)
        augmented_squad = SquadData.from_file(output)
        # NOTE(review): the previous version passed unit="paragraph"; SquadData.count
        # appears to recognise only "paragraphs" (plural), which made both counts 0
        # and the assertion vacuous — confirm against haystack/utils/squad_data.py.
        original_paragraphs = original_squad.count(unit="paragraphs")
        augmented_paragraphs = augmented_squad.count(unit="paragraphs")

    # Guard against a vacuous pass (0 == 0 * factor) if nothing was counted.
    assert original_paragraphs > 0
    # Augmentation is expected to emit ``multiplication_factor`` variants per original
    # paragraph, so the augmented data set is the larger one (the previous assertion
    # had the relation inverted).
    assert augmented_paragraphs == original_paragraphs * multiplication_factor