diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py index b539ded95..4c77dd64b 100644 --- a/haystack/nodes/retriever/dense.py +++ b/haystack/nodes/retriever/dense.py @@ -364,7 +364,10 @@ class DensePassageRetriever(BaseRetriever): self.processor.num_hard_negatives = num_hard_negatives self.processor.num_positives = num_positives - self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True) + if isinstance(self.model, DataParallel): + self.model.module.connect_heads_with_processor(self.processor.tasks, require_labels=True) + else: + self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True) data_silo = DataSilo(processor=self.processor, batch_size=batch_size, distributed=False, max_processes=max_processes) @@ -816,7 +819,10 @@ class TableTextRetriever(BaseRetriever): self.processor.num_hard_negatives = num_hard_negatives self.processor.num_positives = num_positives - self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True) + if isinstance(self.model, DataParallel): + self.model.module.connect_heads_with_processor(self.processor.tasks, require_labels=True) + else: + self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True) data_silo = DataSilo(processor=self.processor, batch_size=batch_size, distributed=False, max_processes=max_processes)