From ceb3e300cd79eeada7ca449343d4d15ea170a2cd Mon Sep 17 00:00:00 2001 From: Xueqing Liu Date: Tue, 4 Oct 2022 10:51:12 -0400 Subject: [PATCH] Issue724 (#745) * fixing issue724 * fixing issue724 --- flaml/nlp/huggingface/data_collator.py | 8 +++++++- test/nlp/test_autohf.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/flaml/nlp/huggingface/data_collator.py b/flaml/nlp/huggingface/data_collator.py index 7f33dc330..2d10f1520 100644 --- a/flaml/nlp/huggingface/data_collator.py +++ b/flaml/nlp/huggingface/data_collator.py @@ -6,7 +6,12 @@ from transformers.data.data_collator import ( ) from collections import OrderedDict -from flaml.data import TOKENCLASSIFICATION, MULTICHOICECLASSIFICATION, SUMMARIZATION +from flaml.data import ( + TOKENCLASSIFICATION, + MULTICHOICECLASSIFICATION, + SUMMARIZATION, + SEQCLASSIFICATION, +) @dataclass @@ -45,5 +50,6 @@ task_to_datacollator_class = OrderedDict( (TOKENCLASSIFICATION, DataCollatorForTokenClassification), (MULTICHOICECLASSIFICATION, DataCollatorForMultipleChoiceClassification), (SUMMARIZATION, DataCollatorForSeq2Seq), + (SEQCLASSIFICATION, DataCollatorWithPadding), ] ) diff --git a/test/nlp/test_autohf.py b/test/nlp/test_autohf.py index ee0ab693f..f21f02543 100644 --- a/test/nlp/test_autohf.py +++ b/test/nlp/test_autohf.py @@ -56,7 +56,7 @@ def test_hf_data(): record_id=0, **automl_settings ) - automl.predict(X_test) + automl.predict(X_test, **{"per_device_eval_batch_size": 2}) automl.predict(["test test", "test test"]) automl.predict( [