Mirror of https://github.com/deepset-ai/haystack.git
Replace torch.device(cuda) with torch.device(cuda:0) in devices initialization (#3184)
commit e0d73f3ae0
parent 20880c9d41
@@ -96,6 +96,12 @@ def initialize_device_settings(
         n_gpu = 1
         # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
         torch.distributed.init_process_group(backend="nccl")
+
+    # HF transformers v4.21.2 pipeline object doesn't accept torch.device("cuda"), it has to be an indexed cuda device
+    # TODO eventually remove once the limitation is fixed in HF transformers
+    device_to_replace = torch.device("cuda")
+    devices_to_use = [torch.device("cuda:0") if device == device_to_replace else device for device in devices_to_use]
+
     logger.info(f"Using devices: {', '.join([str(device) for device in devices_to_use]).upper()}")
     logger.info(f"Number of GPUs: {n_gpu}")
     return devices_to_use, n_gpu
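For context on why the replacement is needed: torch.device("cuda") carries no device index, so it compares unequal to torch.device("cuda:0"), and the HF transformers v4.21.2 pipeline only accepts the indexed form. The following is a minimal standalone sketch of the normalization performed by the added lines; the devices_to_use list here is a made-up example, not the one built earlier in initialize_device_settings.

import torch

# torch.device("cuda") has index None, while torch.device("cuda:0") has index 0,
# so the two objects compare unequal even though they usually refer to the same GPU.
assert torch.device("cuda") != torch.device("cuda:0")

# Hypothetical example list of devices, not taken from the Haystack code.
devices_to_use = [torch.device("cuda"), torch.device("cpu")]

# Same normalization as in the diff: swap any bare "cuda" device for an explicit "cuda:0".
device_to_replace = torch.device("cuda")
devices_to_use = [torch.device("cuda:0") if device == device_to_replace else device for device in devices_to_use]

print(devices_to_use)  # [device(type='cuda', index=0), device(type='cpu')]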