diff --git a/haystack/modeling/utils.py b/haystack/modeling/utils.py index 8afb09bef..21aa54038 100644 --- a/haystack/modeling/utils.py +++ b/haystack/modeling/utils.py @@ -96,6 +96,12 @@ def initialize_device_settings( n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group(backend="nccl") + + # HF transformers v4.21.2 pipeline object doesn't accept torch.device("cuda"); it has to be an indexed CUDA device + # TODO eventually remove once the limitation is fixed in HF transformers + device_to_replace = torch.device("cuda") + devices_to_use = [torch.device("cuda:0") if device == device_to_replace else device for device in devices_to_use] + logger.info(f"Using devices: {', '.join([str(device) for device in devices_to_use]).upper()}") logger.info(f"Number of GPUs: {n_gpu}") return devices_to_use, n_gpu