
DataParallel is used by auto_model with single GPU #2447

Open
H4dr1en opened this issue Feb 2, 2022 · 9 comments

@H4dr1en
Contributor

H4dr1en commented Feb 2, 2022

🐛 Bug description

I am not sure whether it is a bug or a feature:

DataParallel is being applied/patched by idist.auto_model in a single-GPU context (backend=None, nproc_per_node=1). What is the reason behind this choice? Does it bring any speed improvement?

The only way to prevent it is to set os.environ["CUDA_VISIBLE_DEVICES"] = "0" for single-gpu contexts.

Environment

  • PyTorch Version (e.g., 1.4): 1.7.1
  • Ignite Version (e.g., 0.3.0): 0.4.8
  • OS (e.g., Linux): Linux
  • How you installed Ignite (conda, pip, source): pip
  • Python version: 3.8
@vfdev-5
Collaborator

vfdev-5 commented Feb 2, 2022

@H4dr1en I haven't checked the code yet, but I agree that it seems useless to wrap the model with DataParallel when there is only 1 device.

EDIT:

elif torch.cuda.device_count() > 1 and "cuda" in idist.device().type:

It takes all available GPUs and wraps the model with DataParallel on all of them.

> The only way to prevent it is to set os.environ["CUDA_VISIBLE_DEVICES"] = "0" for single-gpu contexts.

Yes, I think this is the simplest way to pick the GPU to use without changing ignite's code and API for choosing a GPU.

I think this is the expected behaviour if you have N GPUs, start the script without any restriction on how many GPUs to use, and use idist.auto_model.

@H4dr1en
Contributor Author

H4dr1en commented Feb 2, 2022

> @H4dr1en I haven't yet checked the code but I agree that this seems useless to wrap the model by DataParallel in case of 1 device. [...] I think this is expected behaviour if you have N GPUs, start the script without any restriction on how many GPUs to use and use idist.auto_model.

DataParallel seems to drastically slow down model training for a single GPU. I tried the cifar10 example from the pytorch-ignite/examples repository on this fork and ran the script on a g4dn.12xlarge AWS instance (4x T4 GPUs).

The default training time reported by ignite is

CIFAR10-Training INFO: Engine run complete. Time taken: 00:29:53

If I set os.environ["CUDA_VISIBLE_DEVICES"] = "0" (line 430), the training time becomes

CIFAR10-Training INFO: Engine run complete. Time taken: 00:05:34

This is the only change I made. If this really comes from DataParallel being inefficient for a single GPU, I see no reason why it should be applied automatically. I'd expect idist.auto_model to give me the best/fastest configuration, in this case no DataParallel. Am I missing something?
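For a sense of scale, the two reported durations can be converted to seconds and compared. This is a throwaway helper written for illustration (the name `to_seconds` is made up, it is not part of ignite):

```python
def to_seconds(hms: str) -> int:
    """Convert an "HH:MM:SS" duration, as printed in the "Time taken" log line, to seconds."""
    hours, minutes, seconds = (int(part) for part in hms.split(":"))
    return hours * 3600 + minutes * 60 + seconds


# Durations copied from the two log lines above.
default_run = to_seconds("00:29:53")     # one process, DataParallel over all 4 GPUs
single_gpu_run = to_seconds("00:05:34")  # CUDA_VISIBLE_DEVICES="0"

slowdown = default_run / single_gpu_run
print(f"{default_run} s vs {single_gpu_run} s -> {slowdown:.2f}x slower")
# prints "1793 s vs 334 s -> 5.37x slower"
```

On these two runs, the default setup is roughly 5.4 times slower than the single-GPU run.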

@vfdev-5
Collaborator

vfdev-5 commented Feb 2, 2022

Thanks for details @H4dr1en !

First, PyTorch recommends using DistributedDataParallel (DDP) instead of DataParallel (DP): https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html

If I understand your code, your infrastructure and the way you launched it correctly, then in the run that reports "Time taken: 00:29:53" you start one process on a machine with 4 GPUs. idist.auto_model therefore takes all 4 GPUs and splits every batch between them for the forward/backward passes, and, if I remember the DataParallel implementation correctly, the model is replicated on each forward call. So you train on 4 GPUs (not 1 GPU) with the specified batch size, which probably explains the slowdown.
You can try multiplying the batch size by 4 and comparing, but I do not think DP will be faster than DDP.
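To make the batch-splitting point concrete, here is a pure-Python sketch of how many samples each GPU receives when one batch is split along its first dimension. It assumes DataParallel's default scatter behaves like `torch.chunk` (chunks of size ceil(B / N), the last one possibly smaller); the helper name and the batch size of 512 are hypothetical, not taken from the cifar10 example:

```python
import math
from typing import List


def per_gpu_batch_sizes(batch_size: int, num_gpus: int) -> List[int]:
    """Sub-batch sizes for one batch split across num_gpus, mimicking torch.chunk."""
    chunk = math.ceil(batch_size / num_gpus)
    full, rest = divmod(batch_size, chunk)
    return [chunk] * full + ([rest] if rest else [])


# With a fixed batch size, each of the 4 GPUs only sees a quarter of it.
print(per_gpu_batch_sizes(512, 4))      # [128, 128, 128, 128]
# Multiplying the batch size by 4 gives every GPU the original batch size.
print(per_gpu_batch_sizes(512 * 4, 4))  # [512, 512, 512, 512]
```

With a small per-GPU share, the fixed per-step cost of replicating the model and gathering outputs can dominate, which is consistent with the suggestion to scale the batch size before comparing.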

When you specify CUDA_VISIBLE_DEVICES=0, you expose only GPU 0 to the script, so training runs on a single GPU without DataParallel.
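A minimal sketch of doing the same from inside the script, assuming it runs before anything in the process initializes CUDA (setting it before `import torch` is the safe choice). The `try`/`except` around the import only exists so the sketch also runs on machines without PyTorch:

```python
import os

# Restrict this process to physical GPU 0. This must happen before CUDA is
# initialized in the process, so do it before importing torch.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

try:
    import torch
except ImportError:  # PyTorch not installed: nothing more to show
    torch = None

if torch is not None and torch.cuda.is_available():
    # Only GPU 0 is visible now, and it is exposed as cuda:0.
    print(torch.cuda.device_count())  # 1
```

From the shell, `CUDA_VISIBLE_DEVICES=0 python main.py` has the same effect without touching the code.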

> I'd expect the idist.auto_model to give me the best/fastest configuration, in this case, no DataParallel.

idist.auto_model picks the best option for the given configuration:

  • if there is a distributed processing group and N GPUs => DDP wrapper
  • if there are N GPUs and no distributed processing group => DP wrapper, so that all GPUs are used instead of only one
  • otherwise, do not wrap the model.
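The three rules above can be summarised as a small pure-Python decision function. This is only a sketch that mirrors the auto_model branches quoted further down in this thread; the function name, its arguments and the returned labels are hypothetical, and it assumes the backend names are the strings "nccl", "gloo", "mpi" and "horovod". The real code reads these values through idist (get_world_size(), backend(), device()) and torch.cuda.device_count(), and also checks whether native/Horovod support is installed:

```python
from typing import Optional


def choose_wrapper(world_size: int, backend: Optional[str], cuda_device_count: int, device_type: str) -> str:
    """Label of the wrapper an auto_model-like function would apply."""
    if world_size > 1:
        # A distributed process group exists, so DP is never considered.
        if backend in ("nccl", "gloo", "mpi"):
            return "DistributedDataParallel"
        if backend == "horovod":
            return "horovod broadcast"  # parameters broadcast from rank 0, no wrapper
        return "none"
    # Single process, but several GPUs are visible: fall back to DataParallel.
    if cuda_device_count > 1 and "cuda" in device_type:
        return "DataParallel"
    return "none"


# One process on the 4-GPU instance (backend=None):
print(choose_wrapper(1, None, 4, "cuda"))    # DataParallel
# Same process with CUDA_VISIBLE_DEVICES="0":
print(choose_wrapper(1, None, 1, "cuda"))    # none
# 4 processes spawned with the NCCL backend:
print(choose_wrapper(4, "nccl", 4, "cuda"))  # DistributedDataParallel
```

The first call is the situation from this issue: with a single process, the only signal left is the number of visible GPUs, which is why DP is chosen.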

What do you think ?

@H4dr1en
Contributor Author

H4dr1en commented Feb 2, 2022

Thanks for clarifying @vfdev-5 !

Indeed, I can observe that I am training on 4 GPUs with DataParallel by default. I understand the logic of idist.auto_model; my confusion comes from the fact that when I set nproc_per_node=1, I expect training to run on one GPU, but it actually trains on all available GPUs with DDP, which makes it slower. Is that correct?

@vfdev-5
Collaborator

vfdev-5 commented Feb 2, 2022

Yeah, I was also confused about how you are using nproc_per_node in your code. By the way, I still don't quite understand what exactly happens when nproc_per_node > 1: a subprocess entering __main__ also gets the config's default nproc_per_node=4, so how does it avoid respawning an infinite number of processes?

> I expect to have the training on one GPU, but it is actually training on all the GPUs available with DDP, which make it slower. Is it correct?

I'd expect training on 4 GPUs in DDP mode (not DP) to be faster than on 1 GPU. In your case, I suppose the slowdown with nproc_per_node=1 and 4 GPUs in DP mode is related to a small batch size.

@H4dr1en
Contributor Author

H4dr1en commented Feb 2, 2022

> By the way, I still do not quite understand what happens exactly if nproc_per_node > 1, what happens with the subprocess entering into __main__, config has nproc_per_node=4 by default, so how it works that it does respawn inf number of processes ?

The subprocesses have args.local_rank defined, so they just start training (skipping L435).
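The guard described here can be sketched with argparse alone. The flag names mirror the ones mentioned in this thread (--local_rank, --nproc_per_node), but the helper functions and the way worker arguments are built are hypothetical, not the actual code of the fork:

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    # Only worker processes receive this; the user-facing launch leaves it unset.
    parser.add_argument("--local_rank", type=int, default=None)
    parser.add_argument("--nproc_per_node", type=int, default=4)
    return parser.parse_args(argv)


def worker_argvs(args: argparse.Namespace) -> List[List[str]]:
    """Arguments for the workers the current process should start."""
    if args.local_rank is not None:
        # Already a worker: go straight to training and never spawn again,
        # even though nproc_per_node still has its default of 4.
        return []
    # Launcher: one worker per GPU, each tagged with its local rank.
    return [["--local_rank", str(rank)] for rank in range(args.nproc_per_node)]


print(worker_argvs(parse_args([])))                     # 4 argument lists, ranks 0..3
print(worker_argvs(parse_args(["--local_rank", "2"])))  # []
```

Because the decision depends on local_rank rather than on nproc_per_node, the recursion stops after one level.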

> I'd expect training on 4 GPUs in DDP mode (not DP) is faster than 1 GPU. In your case, I suppose the slowdown with nproc_per_node=1 and using 4 GPUs in DP mode is related to a small batch size.

Yes, we are on the same page. What I want to raise in this issue is that idist.auto_model, when executed in an environment where nproc_per_node=1, should not use DP regardless of the number of available GPUs, and should train on a single GPU only. Does that make sense?

@vfdev-5
Collaborator

vfdev-5 commented Feb 2, 2022

> What I want to raise in this issue is the fact that idist.auto_model, when being executed with an environment where nproc_per_node=1, regardless of the number of GPUs available, should not use DP, and only train using a single GPU - does it make sense?

Yes, it makes perfect sense, but how can ignite know that you have nproc_per_node=1?

In addition, there can be (old) cases where we would like to use DP: one process using multiple GPUs.

@H4dr1en
Contributor Author

H4dr1en commented Feb 3, 2022

> Yes, it makes perfectly sense, but how ignite can know that you have nproc_per_node=1 ?

Can we maybe check the world_size?

> In addition, there can be (old) cases when we would like to use DP : one process and use multiple GPUs.

Is it justified to keep this use case now that DDP is available and faster than DP? I don't fully understand the different use cases, so I might be wrong, in which case I understand that we should not change this behaviour.

@vfdev-5
Collaborator

vfdev-5 commented Feb 3, 2022

> Can we maybe check the world_size?

Yes, we use world_size to set up DDP. If world_size is defined and > 1, then there is a distributed processing group and there is no point in using DP.

Here is the code:

```python
# Excerpt from the body of idist.auto_model(model, sync_bn=False, **kwargs):
# `model`, `sync_bn`, `kwargs` and `logger` come from the enclosing function.
import torch
import torch.nn as nn

from ignite.distributed import utils as idist
from ignite.distributed.comp_models import horovod as idist_hvd
from ignite.distributed.comp_models import native as idist_native

if idist.get_world_size() > 1:
    bnd = idist.backend()
    if idist.has_native_dist_support and bnd in (idist_native.NCCL, idist_native.GLOO, idist_native.MPI):
        if sync_bn:
            logger.info("Convert batch norm to sync batch norm")
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if torch.cuda.is_available():
            if "device_ids" in kwargs:
                raise ValueError(f"Argument kwargs should not contain 'device_ids', but got {kwargs}")
            lrank = idist.get_local_rank()
            logger.info(f"Apply torch DistributedDataParallel on model, device id: {lrank}")
            kwargs["device_ids"] = [lrank]
        else:
            logger.info("Apply torch DistributedDataParallel on model")
        model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
    elif idist.has_hvd_support and bnd == idist_hvd.HOROVOD:
        import horovod.torch as hvd

        logger.info("Broadcast the initial variable states from rank 0 to all other processes")
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
# not distributed but multiple GPUs reachable so data parallel model
elif torch.cuda.device_count() > 1 and "cuda" in idist.device().type:
    logger.info("Apply torch DataParallel on model")
    model = torch.nn.parallel.DataParallel(model, **kwargs)
```

If there is no distributed processing group but more than one GPU is available, we can use DP.

To enable a distributed processing group, the user can specify the backend in idist.Parallel, and a group will be created automatically from all available processes.

> Is it justified to keep this use case, now that DDP is out and is faster than DP? I don't fully understand the different use cases, so I might be wrong, in which case I understand that we should not change this behaviour

In our case, we leave the decision to the user. After launching a single process (python main.py) on a machine with N GPUs, the user can either stay with that single process and use DP, or spawn N subprocesses (ignite then internally creates a distributed process group, so auto_model uses DDP).
