Describe the bug
When I train a model (Qwen 14B) using DeepSpeed, the results of two otherwise identical runs are significantly inconsistent.
First run:
```
{'loss': 2.0, 'grad_norm': 20.78711275277921, 'learning_rate': 9.330127018922195e-06, 'epoch': 0.17}
{'loss': 1.5547, 'grad_norm': 59.310739679010915, 'learning_rate': 7.500000000000001e-06, 'epoch': 0.33}
{'loss': 0.5781, 'grad_norm': 15.005973828390784, 'learning_rate': 5e-06, 'epoch': 0.5}
{'loss': 0.3184, 'grad_norm': 9.697505381713714, 'learning_rate': 2.5000000000000015e-06, 'epoch': 0.67}
{'loss': 0.1318, 'grad_norm': 6.17889934461755, 'learning_rate': 6.698729810778065e-07, 'epoch': 0.83}
{'loss': 0.0859, 'grad_norm': 4.75770403827632, 'learning_rate': 0.0, 'epoch': 1.0}
```
Second run:
```
{'loss': 2.0, 'grad_norm': 20.78707942800991, 'learning_rate': 9.330127018922195e-06, 'epoch': 0.17}
{'loss': 1.5547, 'grad_norm': 43.96039229387614, 'learning_rate': 7.500000000000001e-06, 'epoch': 0.33}
{'loss': 0.6484, 'grad_norm': 12.841190229236128, 'learning_rate': 5e-06, 'epoch': 0.5}
{'loss': 0.3105, 'grad_norm': 9.612021710541004, 'learning_rate': 2.5000000000000015e-06, 'epoch': 0.67}
{'loss': 0.1172, 'grad_norm': 5.885649690333212, 'learning_rate': 6.698729810778065e-07, 'epoch': 0.83}
{'loss': 0.0713, 'grad_norm': 4.291706145544393, 'learning_rate': 0.0, 'epoch': 1.0}
```
The grad_norm already differs at step 2, and the loss differs at step 3.
If I add a synchronization in the code below, the results remain consistent (I have tested this three times).
DeepSpeed/deepspeed/runtime/zero/stage_1_and_2.py
Line 965 in 462def4
```python
def average_tensor(self, tensor):
    if self.overlap_comm:
        stream = self.reduction_stream
        if not get_accelerator().is_synchronized_device():
            stream.wait_stream(get_accelerator().current_stream())
            stream.synchronize()  # added: force-synchronize the allreduce stream
    else:
        stream = get_accelerator().current_stream()

    with get_accelerator().stream(stream):
        if not self.reduce_scatter:
            self.gradient_reduction_w_predivide(tensor)
            return
```
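For context on why this one-line change alters behavior, here is a minimal sketch, in plain PyTorch rather than DeepSpeed code, of the difference between the two primitives: `wait_stream` only enqueues a GPU-side ordering dependency and returns immediately on the host, while `synchronize` blocks the host thread until everything previously queued on the stream has finished.

```python
# Illustrative sketch of the two primitives used above (plain PyTorch,
# not DeepSpeed code); tensor sizes are arbitrary.
import torch

s = torch.cuda.Stream()
x = torch.randn(4096, 4096, device="cuda")

# wait_stream: GPU-side ordering only. Work later enqueued on `s` will not
# start until everything already queued on the current stream finishes,
# but the host returns immediately and can keep enqueueing kernels.
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    y = x @ x  # runs on `s`, overlapping with anything enqueued afterwards

# synchronize: host-side barrier. The CPU blocks here until ALL work
# previously queued on `s` (e.g., an earlier bucket's allreduce) has
# completed -- this is what the patch above inserts, at the cost of overlap.
s.synchronize()
```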
With this change, the results remain consistent across three runs:
```
{'loss': 2.0, 'grad_norm': 20.78732614100428, 'learning_rate': 9.330127018922195e-06, 'epoch': 0.17}
{'loss': 1.5469, 'grad_norm': 16.731201175437484, 'learning_rate': 7.500000000000001e-06, 'epoch': 0.33}
{'loss': 0.5586, 'grad_norm': 14.621543271989035, 'learning_rate': 5e-06, 'epoch': 0.5}
{'loss': 0.3066, 'grad_norm': 9.533331203714019, 'learning_rate': 2.5000000000000015e-06, 'epoch': 0.67}
{'loss': 0.1226, 'grad_norm': 5.927102870076524, 'learning_rate': 6.698729810778065e-07, 'epoch': 0.83}
{'loss': 0.0796, 'grad_norm': 4.49918771613179, 'learning_rate': 0.0, 'epoch': 1.0}
```
Analysis:
Is the double-buffer and overlap mechanism actually correct?
Consider an extreme case: the reduction_stream runs slowly while the current_stream runs fast.
If the reduction_stream has not finished when filling of the second self.ipg_buffer begins, will the data in self.ipg_buffer be overwritten? If it is overwritten, the reduced gradients will be incorrect.
If we add the synchronization, the self.ipg_buffer fills and the reduction_stream communication are forced to interleave, ensuring that a buffer is never overwritten while it is still being reduced. A minimal sketch of the suspected race is shown below.
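The following is a self-contained sketch of the hazard described above, using plain PyTorch streams rather than DeepSpeed internals; `bufs` and `reduce_on_side_stream` are illustrative stand-ins for self.ipg_buffer and the allreduce on reduction_stream, not DeepSpeed APIs.

```python
# A minimal sketch of the suspected race (plain PyTorch; names are
# illustrative stand-ins, not DeepSpeed APIs).
import torch

comm_stream = torch.cuda.Stream()  # plays the role of reduction_stream
bufs = [torch.empty(1 << 20, device="cuda") for _ in range(2)]  # double buffer

def reduce_on_side_stream(buf):
    # Consumer: the side stream waits for the producer's writes, then
    # "reduces" the buffer (placeholder for the real allreduce).
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        buf.mul_(0.5)

for step in range(10):
    buf = bufs[step % 2]
    # Producer: refill the buffer on the default stream. HAZARD: nothing
    # makes the default stream wait for the PREVIOUS reduction of this same
    # buffer (two steps ago), so if comm_stream is slow this write can land
    # while that reduction is still reading/writing `buf`.
    buf.normal_()
    reduce_on_side_stream(buf)
```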
To Reproduce
Model: Qwen-14B (https://huggingface.co/Qwen/Qwen-14B)
DeepSpeed: 0.10.2
DeepSpeed config:
{ "train_micro_batch_size_per_gpu": "auto", "bf16": { "enabled": "auto" }, "fp16": { "enabled": "auto", "loss_scale": 0, "initial_scale_power": 16, "loss_scale_window": 1000, "hysteresis": 2, "min_loss_scale": 1 }, "zero_optimization": { "stage": 2, "allgather_partitions": true, "allgather_bucket_size": 1e9, "overlap_comm": true, "reduce_scatter": false, "reduce_bucket_size": 5e8, "contiguous_gradients": true } }
Follow-up discussion:

> Sorry, I don't think this is the correct fix. Forcing stream.synchronize() means the corresponding NCCL call becomes a blocking call and no longer overlaps with subsequent compute.

> I don't think so. Even with the forced stream.synchronize() there is still overlap, because we have two buffers: the backward pass can still overlap with the communication.
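One standard way to get the safety the reporter wants without the host-blocking synchronize that the reply objects to is per-buffer CUDA events: record an event on the reduction stream when it finishes with a buffer, and have the producer stream wait on that event (GPU-side) before refilling the same buffer. The following is a sketch of that general pattern, not the fix DeepSpeed actually adopted; all names are illustrative.

```python
# Sketch of an event-based, non-blocking alternative (general CUDA pattern,
# not the actual DeepSpeed fix); names are illustrative.
import torch

comm_stream = torch.cuda.Stream()
bufs = [torch.empty(1 << 20, device="cuda") for _ in range(2)]
done = [torch.cuda.Event() for _ in range(2)]  # "reduction finished" per buffer

for step in range(10):
    i = step % 2
    # GPU-side wait: the default stream stalls only if the previous
    # reduction of bufs[i] is still running; the host does not block.
    torch.cuda.current_stream().wait_event(done[i])
    bufs[i].normal_()  # producer refill, now safe

    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        bufs[i].mul_(0.5)            # placeholder for the allreduce
        done[i].record(comm_stream)  # mark bufs[i] reusable
```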