[BUG] deepspeed overlap_comm data race #5545
Additional context deepspeed==0.14.2
github-merge-queue bot pushed a commit that referenced this issue on Jun 10, 2024:
`deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer.average_tensor` only makes the reduction stream wait for the default stream. This is fine when computation takes longer than communication. When communication takes longer, however, the ipg_buffer can be overwritten before the communication that reads it has completed.

![image](https://github.com/microsoft/DeepSpeed/assets/35059704/950cbf8a-f439-4cf9-a364-dcdfd47f46a0)

The simplest fix is to also make the default stream wait for the reduction stream at the **same point**. For example, at point 1 the `reduction stream` needs to wait for '2', so we add a wait_stream that makes the `reduction stream` wait for the `default stream`. The `default stream` likewise needs to wait for 'A', so we add a wait_stream that makes the `default stream` wait for the `reduction stream` before 'B'.

![image](https://github.com/microsoft/DeepSpeed/assets/35059704/588a9469-d3f9-4c39-976d-3ae0502cf1d1)

Compared with the change in #5523, wait_stream does not cause host synchronization. Compared with the change in #5545, this modification is simpler and follows the same logic: each stream waits only for what it actually depends on.

---

With this modification, the losses of Qwen-1.5 with and without overlap_comm are identical.

![image](https://github.com/microsoft/DeepSpeed/assets/35059704/4d48d54e-e55b-4230-8b99-93549910a43f)

---

By contrast, without the modification there is an obvious loss gap at a small sequence length, which means a short computation time.

![image](https://github.com/microsoft/DeepSpeed/assets/35059704/c80af498-3358-4e36-9b13-8f266551d51d)

Co-authored-by: gp513 <guopeng34@huawei.com>
Co-authored-by: CurryRice233 <nmeia@qq.com>
Co-authored-by: Joe Mayer <114769929+jomayeri@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
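Why the one-way wait only breaks when communication is slower than computation can be seen with a small timing model. This is a hypothetical, simplified sketch and not DeepSpeed code. It assumes two alternating buffers (standing in for the double-buffered ipg_buffer used with overlap_comm), a fixed compute time and communication time per bucket, and an instantaneous copy of each bucket into its buffer. The function name `simulate` and its parameters are illustrative only. With `two_way_sync=False` only the reduction stream waits for the default stream. With `two_way_sync=True` the default stream also waits, at the same point, for the reduction work enqueued before it, as the commit above describes.

```python
def simulate(compute_time, comm_time, num_buckets, num_buffers=2, two_way_sync=True):
    """Toy timing model of gradient buckets flowing through reusable buffers.

    Returns (bucket, earlier_bucket) pairs where the default stream starts writing
    `bucket` into a buffer while the reduction of `earlier_bucket`, which used the
    same buffer, is still in flight (i.e. the data race).
    """
    t_default = 0.0   # time at which the default (compute) stream becomes idle
    t_reduce = 0.0    # time at which the reduction stream becomes idle
    reduce_end = {}   # bucket index -> finish time of its reduction
    races = []
    for k in range(num_buckets):
        # Backward compute produces bucket k's gradients; the copy into the
        # buffer is treated as instantaneous and happens when compute finishes.
        t_default += compute_time
        prev = k - num_buffers  # earlier bucket that used the same buffer
        if prev >= 0 and t_default < reduce_end[prev]:
            races.append((k, prev))

        # Sync point: the reduction for bucket k is enqueued here.
        reduce_busy_until = t_reduce  # reduction work enqueued before this point
        # Reduction stream waits for the default stream (the original behaviour).
        t_reduce = max(t_reduce, t_default) + comm_time
        reduce_end[k] = t_reduce
        if two_way_sync:
            # The fix: the default stream also waits, at the same point, for
            # the reduction work that was enqueued before it.
            t_default = max(t_default, reduce_busy_until)
    return races


# Communication faster than computation: the one-way wait is enough.
print(simulate(1.0, 0.5, 8, two_way_sync=False))     # []
# Communication slower than computation: bucket 2 overwrites bucket 0's buffer.
print(simulate(1.0, 3.0, 8, two_way_sync=False)[0])  # (2, 0)
# Waiting in both directions removes the race.
print(simulate(1.0, 3.0, 8, two_way_sync=True))      # []
```

In this model the extra wait only delays the default stream by the reduction of the bucket that previously occupied the buffer about to be reused. That matches the point made above that a stream-to-stream wait avoids a host synchronization.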
sfc-gh-reyazda pushed a commit to Snowflake-Labs/DeepSpeed that referenced this issue on Jun 10, 2024.
Describe the bug
As illustrated below, DeepSpeed's overlap buffer design has a potential data race.
I have written a patch to fix it.
Could you kindly help diagnose and fix this issue?
To Reproduce
debug.py
Unexpected behavior
System info
Launcher context
dp_zero1_fp16.yaml
debug.sh
Additional context
A similar issue is also present in ZeRO stage 3, but I have not prepared a patch for it yet.