* fixing issue724

* fixing issue724
This commit is contained in:
Xueqing Liu 2022-10-04 10:51:12 -04:00 committed by GitHub
parent 0df54ec3f3
commit ceb3e300cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 2 deletions

View File

@ -6,7 +6,12 @@ from transformers.data.data_collator import (
)
from collections import OrderedDict
from flaml.data import TOKENCLASSIFICATION, MULTICHOICECLASSIFICATION, SUMMARIZATION
from flaml.data import (
TOKENCLASSIFICATION,
MULTICHOICECLASSIFICATION,
SUMMARIZATION,
SEQCLASSIFICATION,
)
@dataclass
@ -45,5 +50,6 @@ task_to_datacollator_class = OrderedDict(
(TOKENCLASSIFICATION, DataCollatorForTokenClassification),
(MULTICHOICECLASSIFICATION, DataCollatorForMultipleChoiceClassification),
(SUMMARIZATION, DataCollatorForSeq2Seq),
(SEQCLASSIFICATION, DataCollatorWithPadding),
]
)

View File

@ -56,7 +56,7 @@ def test_hf_data():
record_id=0,
**automl_settings
)
automl.predict(X_test)
automl.predict(X_test, **{"per_device_eval_batch_size": 2})
automl.predict(["test test", "test test"])
automl.predict(
[