This commit is contained in:
Alon Eirew 2021-11-09 13:52:07 +02:00 committed by GitHub
parent cd8666f904
commit 861522b6b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -364,7 +364,10 @@ class DensePassageRetriever(BaseRetriever):
self.processor.num_hard_negatives = num_hard_negatives
self.processor.num_positives = num_positives
self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True)
if isinstance(self.model, DataParallel):
self.model.module.connect_heads_with_processor(self.processor.tasks, require_labels=True)
else:
self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True)
data_silo = DataSilo(processor=self.processor, batch_size=batch_size, distributed=False, max_processes=max_processes)
@ -816,7 +819,10 @@ class TableTextRetriever(BaseRetriever):
self.processor.num_hard_negatives = num_hard_negatives
self.processor.num_positives = num_positives
self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True)
if isinstance(self.model, DataParallel):
self.model.module.connect_heads_with_processor(self.processor.tasks, require_labels=True)
else:
self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True)
data_silo = DataSilo(processor=self.processor, batch_size=batch_size, distributed=False,
max_processes=max_processes)