mirror of
https://github.com/microsoft/autogen.git
synced 2026-01-06 04:01:12 +00:00
parent
0df54ec3f3
commit
ceb3e300cd
@ -6,7 +6,12 @@ from transformers.data.data_collator import (
|
||||
)
|
||||
from collections import OrderedDict
|
||||
|
||||
from flaml.data import TOKENCLASSIFICATION, MULTICHOICECLASSIFICATION, SUMMARIZATION
|
||||
from flaml.data import (
|
||||
TOKENCLASSIFICATION,
|
||||
MULTICHOICECLASSIFICATION,
|
||||
SUMMARIZATION,
|
||||
SEQCLASSIFICATION,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -45,5 +50,6 @@ task_to_datacollator_class = OrderedDict(
|
||||
(TOKENCLASSIFICATION, DataCollatorForTokenClassification),
|
||||
(MULTICHOICECLASSIFICATION, DataCollatorForMultipleChoiceClassification),
|
||||
(SUMMARIZATION, DataCollatorForSeq2Seq),
|
||||
(SEQCLASSIFICATION, DataCollatorWithPadding),
|
||||
]
|
||||
)
|
||||
|
||||
@ -56,7 +56,7 @@ def test_hf_data():
|
||||
record_id=0,
|
||||
**automl_settings
|
||||
)
|
||||
automl.predict(X_test)
|
||||
automl.predict(X_test, **{"per_device_eval_batch_size": 2})
|
||||
automl.predict(["test test", "test test"])
|
||||
automl.predict(
|
||||
[
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user