Significant performance degradation with multi-GPU training on newer torch/transformers #30840
Comments
Accelerate isn't the issue. Timings on my 2x4090 across accelerate versions (0.x) were all in the same ballpark.
So this issue might involve the Trainer, although I didn't actually see any relevant changes there. In a last-ditch effort:
Now we are seeing issues from transformers instead. Narrowing it down further (assuming same
So the issue stems from transformers 4.32.1 + torch 2.0.1.
I'm not sure it's worth us fixing, since updating your torch version will solve this problem.
Is there a specific use case for needing torch 2.0.1 where you can't use a later version?
Also: one thing I found that could affect it by about an hour was the temperature my GPU was at. If it was cool / a cold start, it could be an hour slower. There are lots of variables at play here, and I'm not sure exactly what is causing your issue, even after looking thoroughly.
@muellerzr Thanks a lot for checking. All your tests seem to be in the same ballpark, so I don't think this really reproduces the issue. Also note that the performance seems to degrade further with a larger number of GPUs, so 2x4090 may not be enough to reproduce it. I can run some more tests on my end if you have suggestions.

Regarding the torch version: unfortunately, the problem (as I have described above) is that recent torch/transformers versions are actually the ones that are slow, so I cannot just upgrade them to fix the problem. In fact, I actually upgraded the libraries when I noticed the problem.

Regarding temperature: I don't think that's the issue. I have tested this on multiple machines and multiple times, and switching the env changes the runtime significantly, so I doubt temperature is to blame. Also, the issue here is not ~1 hr of worse performance but a factor of ~2 in many cases (~6 hrs vs ~12 hrs, or 66 hrs vs 82 hrs as in my example above).
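A quick way to compare the two setups more precisely than the progress-bar ETA is to time the optimization steps directly. The sketch below is not from this thread; it is a minimal example meant for the test script in this issue, and the StepTimerCallback name and log_every default are made up for illustration.

import time

from transformers import TrainerCallback


class StepTimerCallback(TrainerCallback):
    """Logs the average wall-clock time per optimization step on the main process."""

    def __init__(self, log_every: int = 50):
        self.log_every = log_every
        self._start = None
        self._durations = []

    def on_step_begin(self, args, state, control, **kwargs):
        self._start = time.perf_counter()

    def on_step_end(self, args, state, control, **kwargs):
        self._durations.append(time.perf_counter() - self._start)
        if state.global_step % self.log_every == 0 and state.is_world_process_zero:
            avg = sum(self._durations) / len(self._durations)
            print(f"step {state.global_step}: {avg:.3f}s/step on average")
            self._durations.clear()

Passing it via Trainer(..., callbacks=[StepTimerCallback()]) in both environments gives a per-step number that can be compared directly across torch/transformers versions.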
I'll see if I can get access to an 8-node system to debug.
BUT that would mean we're hitting a ton of unnecessary distributed communications somewhere along the line (since it was working before).
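One way to test the unnecessary-communication hypothesis would be to profile a few training steps and see how much CUDA time lands in NCCL/collective kernels versus compute. This is only a sketch, not something run in this thread; run_one_step is a placeholder for whatever executes a single forward/backward/optimizer step.

from torch.profiler import ProfilerActivity, profile


def profile_training_steps(run_one_step, num_steps: int = 10):
    # Profile a handful of steps; in a communication-bound run, NCCL kernels
    # (e.g. all-reduce) dominate the top of the sorted table.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(num_steps):
            run_one_step()
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

Comparing that table between the fast and slow environments would show whether the extra time is spent in communication kernels or elsewhere.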
I ran some tests again (all done in fresh envs):
It looks like recent transformers/accelerate versions are only slightly worse when used with
Let me dig today and see if we have any torch 2.2+ import checks that could differ.
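For context, an "import check" here means a torch-version gate that selects a code path at import time, so different torch versions can silently take different branches. The snippet below is only a generic illustration of such a gate, not actual transformers code.

import torch
from packaging import version

# Generic version gate: behavior differs depending on the installed torch.
IS_TORCH_2_2_PLUS = version.parse(torch.__version__) >= version.parse("2.2.0")

if IS_TORCH_2_2_PLUS:
    ...  # newer code path
else:
    ...  # fallback for older torch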
Is there any update on this, @muellerzr?
I thought maybe this has something to do with iterable vs. map-style datasets, so I did the following test, but it's the same story.

from typing import Iterator

import torch
from transformers import T5ForConditionalGeneration, Trainer, TrainingArguments
from torch.utils.data import IterableDataset, Dataset


class IterableDummyDataset(IterableDataset):
    def __iter__(self) -> Iterator:
        while True:
            yield {
                "input_ids": torch.randint(4000, size=(512,)),
                "labels": torch.randint(4000, size=(64,)),
            }


class MapDummyDataset(Dataset):
    def __len__(self):
        return 1000

    def __getitem__(self, i):
        return {
            "input_ids": torch.randint(4000, size=(512,)),
            "labels": torch.randint(4000, size=(64,)),
        }


if __name__ == "__main__":
    model = T5ForConditionalGeneration.from_pretrained("google/t5-efficient-small")
    dataset = MapDummyDataset()
    training_args = TrainingArguments(
        output_dir="./output/",
        max_steps=1_000_000,
        per_device_train_batch_size=16,
    )
    trainer = Trainer(model=model, train_dataset=dataset, args=training_args)
    trainer.train()
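For the multi-GPU comparison, this script can be launched the same way as in the reproduction steps below, e.g. `torchrun --nproc-per-node=4 test.py` in each environment.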
This really looks like a torch issue. I have opened an issue here: pytorch/pytorch#127077
Thanks for finding and reproducing in torch! Will keep a close eye on it 🤗
*Description of changes:* This PR relaxes `torch` and `transformers` versions to allow for older versions that were used during original training. This is needed in light of recent `torch`/`transformers` versions being slower with DDP. Relevant issues (but the problem may be deeper than these):

- huggingface/transformers#30840
- pytorch/pytorch#127077
- NVIDIA/nccl#1298

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Co-authored-by: Abdul Fatir Ansari <ansarnd@amazon.com>
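For context, "relaxing" the versions means widening the allowed ranges in the package's dependency metadata rather than pinning to the latest releases. The fragment below is a hypothetical illustration only; the package name, file, and exact bounds are made up and not taken from this PR.

from setuptools import setup

setup(
    name="example-forecasting-package",  # hypothetical package
    install_requires=[
        # Widened lower bounds so the older, faster-with-DDP versions remain installable.
        "torch>=2.0",
        "transformers>=4.30",
        "accelerate>=0.20",
    ],
)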
System Info
Information
Tasks
- `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)

Reproduction
I am using a `g5.12xlarge` EC2 instance for this test, but I observed this issue on other machines as well. This is just a minimal example to demonstrate the issue; in my actual usage, the degradation is even worse.

1. Create `env1` and install: `pip install transformers torch accelerate`.
2. Create `env2` and install: `pip install transformers==4.30.2 torch==2.0.1 accelerate==0.20.3`.
3. Run `torchrun --nproc-per-node=4 test.py` in each environment.
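When comparing the two environments, it can help to confirm which versions (including the NCCL build that torch links against) each env actually resolved. The check below is not part of the original report, just a convenience snippet to run in both envs.

import accelerate
import torch
import transformers

print("torch        :", torch.__version__)
print("transformers :", transformers.__version__)
print("accelerate   :", accelerate.__version__)
print("CUDA         :", torch.version.cuda)
print("NCCL         :", torch.cuda.nccl.version())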
Observations

- In `env1`, GPU0 utilization keeps fluctuating and the estimated training time is shown as ~82 hrs.
- In `env2`, all GPUs have utilization maxed out and the estimated training time is shown as ~66 hrs.

Expected behavior
Both environments should have similar training time.