Replace torch.device("cuda") with torch.device("cuda:0") in devices initialization (#3184)

Vladimir Blagojevic 2022-09-08 09:36:38 -04:00 committed by GitHub
parent 20880c9d41
commit e0d73f3ae0


@@ -96,6 +96,12 @@ def initialize_device_settings(
         n_gpu = 1
         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
         torch.distributed.init_process_group(backend="nccl")
+    # HF transformers v4.21.2 pipeline object doesn't accept torch.device("cuda"), it has to be an indexed cuda device
+    # TODO eventually remove once the limitation is fixed in HF transformers
+    device_to_replace = torch.device("cuda")
+    devices_to_use = [torch.device("cuda:0") if device == device_to_replace else device for device in devices_to_use]
     logger.info(f"Using devices: {', '.join([str(device) for device in devices_to_use]).upper()}")
     logger.info(f"Number of GPUs: {n_gpu}")
     return devices_to_use, n_gpu
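
For reference, a minimal standalone sketch of the replacement logic outside the Haystack codebase. The helper name `normalize_devices` is hypothetical, and the claim about HF transformers v4.21.2 is taken from the comment in the diff above. The core point is that a bare torch.device("cuda") has index None and therefore compares unequal to torch.device("cuda:0"), which is why the explicit substitution is needed before handing devices to the pipeline.

```python
import torch

def normalize_devices(devices):
    """Map a bare torch.device("cuda") to the indexed torch.device("cuda:0").

    Hypothetical helper mirroring the logic this commit adds inline:
    HF transformers v4.21.2 pipelines reportedly reject a non-indexed
    CUDA device, so replace it with the first CUDA device.
    """
    bare_cuda = torch.device("cuda")
    return [torch.device("cuda:0") if d == bare_cuda else d for d in devices]

# torch.device("cuda") == torch.device("cuda:0") is False (index None vs. 0),
# so only the bare device gets rewritten; everything else passes through.
devices = [torch.device("cuda"), torch.device("cpu")]
print(normalize_devices(devices))
# [device(type='cuda', index=0), device(type='cpu')]
```

Note that the commit performs this substitution in place inside initialize_device_settings rather than via a helper; the TODO marks it as a temporary workaround to be removed once the limitation is fixed upstream.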