From e0d73f3ae06c0b4841f8ff8be53475caeff1a41e Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 8 Sep 2022 09:36:38 -0400 Subject: [PATCH] Replace torch.device(cuda) with torch.device(cuda:0) in devices initialization (#3184) --- haystack/modeling/utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/haystack/modeling/utils.py b/haystack/modeling/utils.py index 8afb09bef..21aa54038 100644 --- a/haystack/modeling/utils.py +++ b/haystack/modeling/utils.py @@ -96,6 +96,12 @@ def initialize_device_settings( n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group(backend="nccl") + + # HF transformers v4.21.2 pipeline object doesn't accept torch.device("cuda"), it has to be an indexed cuda device + # TODO eventually remove once the limitation is fixed in HF transformers + device_to_replace = torch.device("cuda") + devices_to_use = [torch.device("cuda:0") if device == device_to_replace else device for device in devices_to_use] + logger.info(f"Using devices: {', '.join([str(device) for device in devices_to_use]).upper()}") logger.info(f"Number of GPUs: {n_gpu}") return devices_to_use, n_gpu