mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-07 21:33:39 +00:00
parent
cd8666f904
commit
861522b6b1
@ -364,6 +364,9 @@ class DensePassageRetriever(BaseRetriever):
|
|||||||
self.processor.num_hard_negatives = num_hard_negatives
|
self.processor.num_hard_negatives = num_hard_negatives
|
||||||
self.processor.num_positives = num_positives
|
self.processor.num_positives = num_positives
|
||||||
|
|
||||||
|
if isinstance(self.model, DataParallel):
|
||||||
|
self.model.module.connect_heads_with_processor(self.processor.tasks, require_labels=True)
|
||||||
|
else:
|
||||||
self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True)
|
self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True)
|
||||||
|
|
||||||
data_silo = DataSilo(processor=self.processor, batch_size=batch_size, distributed=False, max_processes=max_processes)
|
data_silo = DataSilo(processor=self.processor, batch_size=batch_size, distributed=False, max_processes=max_processes)
|
||||||
@ -816,6 +819,9 @@ class TableTextRetriever(BaseRetriever):
|
|||||||
self.processor.num_hard_negatives = num_hard_negatives
|
self.processor.num_hard_negatives = num_hard_negatives
|
||||||
self.processor.num_positives = num_positives
|
self.processor.num_positives = num_positives
|
||||||
|
|
||||||
|
if isinstance(self.model, DataParallel):
|
||||||
|
self.model.module.connect_heads_with_processor(self.processor.tasks, require_labels=True)
|
||||||
|
else:
|
||||||
self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True)
|
self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True)
|
||||||
|
|
||||||
data_silo = DataSilo(processor=self.processor, batch_size=batch_size, distributed=False,
|
data_silo = DataSilo(processor=self.processor, batch_size=batch_size, distributed=False,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user