memory leak when compiling collective + view + wait() #126338
Comments
Here is a smaller repro: https://github.com/pytorch-labs/float8_experimental/pull/260/files
pytorch-labs/float8_experimental#262 is a hotfix for float8, although from talking to @eellison it sounds like we can potentially fix this in inductor by having inductor not convert views into non-views when we know they came directly from a collective.
Summary: I'm going to write a more detailed post internally to explain this memory leak.

Tracking issue for a better fix in inductor: pytorch/pytorch#126338

Pull Request resolved: #262
Reviewed By: drisspg
Differential Revision: D57464230
Pulled By: bdhirsh
fbshipit-source-id: 134c50e95045c43f95b5aec4dd3df496ff3fb9a3
We should fix this ASAP, but there is a hotfix already landed. Arguably this should still be high priority; will look soon.
I'm confused about why the kernel `triton_poi_fused_view_3` is necessary. The code responsible for generating this kernel can be found here. This op …
# Summary

Different take on this one: #126338. We should probably not allow this mapping for 'compute' ops, e.g. reductions.

### Corresponding fp8 PR

pytorch-labs/float8_experimental#263

Pull Request resolved: #126556
Approved by: https://github.com/wanchaol
Vasily made some repro instructions here (the repro is pretty large: you need to run torchtitan with tensor parallel + float8 + compile): https://docs.google.com/document/d/1IxjSQnilDOHHMpjqB8t2vFchMpZiPseDX7Vls73VrNA/edit#heading=h.30hlzuqf8znv
The crux of the problem is that we end up with an ATen FX graph with these nodes:
(full aten graph: P1363866107)
And inductor lowers this into the following:
(full inductor output: P1363866330)
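Since the paste links above are internal, here is a hedged sketch of the shape of the generated wrapper code being described below. The buffer names, ops, and argument lists are illustrative placeholders, not actual inductor output:

```python
# Illustrative shape of the generated wrapper, not real inductor output:
buf0 = torch.ops._c10d_functional.all_gather_into_tensor(arg0, ...)  # async collective
buf1 = empty_strided_cuda(...)               # out-of-place buffer for the "view"
# the bitcast kernel reads buf0 *before* any wait: tmp.to(tl.float8e5, bitcast=True)
triton_poi_fused_view_3.run(buf0, buf1, ...)
buf2 = torch.ops._c10d_functional.wait_tensor(buf1)  # wait lands on the copy, not on buf0
del buf1  # the copy is freed, but buf0 (the allgather output) is never waited on
```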
The problem is that:

(1) Normally, AsyncCollectiveTensor tries to ensure that every collective op that is issued is immediately followed by a `wait_tensor()`.

(2) This Float8 code customizes the collective call here to add an extra view op after the allgather. That view op gets inserted between the collective and the `wait_tensor()`.

(3) Normally this is an OK thing to do: view ops don't read the data of their input, only its metadata, so it is acceptable to view the result of a collective before syncing and then sync afterward (before performing any compute). A sketch of this eager pattern follows the list.

(4) However, inductor lowers the view op (`view.dtype`) into an out-of-place kernel (the triton bitcast kernel above, `tl.to(tl.float8e5, bitcast=True)`). Inductor then dels the output of the view kernel, but now that the aliasing is removed, we never del the input to the view kernel (the allgather).
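For reference, a minimal sketch of the eager pattern described in (1)-(3), using the functional collectives API; the `group` handle, the function name, and the float8 dtype are illustrative assumptions, not the actual Float8 code:

```python
import torch
import torch.distributed._functional_collectives as funcol

def allgather_then_view(x: torch.Tensor, group) -> torch.Tensor:
    # (1) issue the collective; the result is not yet synced
    y = funcol.all_gather_tensor(x, gather_dim=0, group=group)
    # (2) a dtype view is inserted between the collective and the wait
    # (assumes x's dtype/shape make the reinterpretation valid)
    y = y.view(torch.float8_e5m2)  # metadata-only in eager, so safe pre-wait
    # (3) sync before any compute actually reads the data
    return funcol.wait_tensor(y)
```

In eager mode this is fine, because the view never touches the bytes; the problem only appears once inductor turns the view into a real kernel.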
I am not entirely sure how this results in a memory leak (in theory, that allgather output is still a bound temporary that should go out of scope when the inductor code finishes running), but doing the bitcast on uninitialized memory (before the allgather is finished) seems equally bad.

After looking at the code with @drisspg, we realized that if you issue a collective but never call `wait_tensor()` on the result, `c10d` is obligated to leak this memory.
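One way to avoid this failure mode, consistent with the analysis above, is to ensure the `wait_tensor()` is issued directly on the raw collective output before any view is taken, so nothing can be lowered in between. A minimal sketch of that general idea (not the literal diff from pytorch-labs/float8_experimental#262):

```python
import torch
import torch.distributed._functional_collectives as funcol

# Wait-before-view sketch; x and group as in the earlier sketch:
y = funcol.all_gather_tensor(x, gather_dim=0, group=group)
y = funcol.wait_tensor(y)        # c10d sees the wait and can release the buffer
y = y.view(torch.float8_e5m2)    # the dtype view happens only after syncing
```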
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire