
memory leak when compiling collective + view + wait() #126338

Open
bdhirsh opened this issue May 15, 2024 · 4 comments
Labels
module: inductor, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

bdhirsh commented May 15, 2024

Vasily made some repro instructions here (the repro is pretty large: you need to run torchtitan with tensor parallel + float8 + compile): https://docs.google.com/document/d/1IxjSQnilDOHHMpjqB8t2vFchMpZiPseDX7Vls73VrNA/edit#heading=h.30hlzuqf8znv

The crux of the problem is that we end up with an ATen FX graph with these nodes:

all_gather_into_tensor_2: "u8[4, 2048, 4096]" = torch.ops._c10d_functional.all_gather_into_tensor.default(view_109, 4, '5');  view_109 = None
view_110: "f8e5m2[4, 2048, 4096]" = torch.ops.aten.view.dtype(all_gather_into_tensor_2, torch.float8_e5m2);  all_gather_into_tensor_2 = None
wait_tensor_18: "f8e5m2[4, 2048, 4096]" = torch.ops._c10d_functional.wait_tensor.default(view_110);  view_110 = None

(full aten graph: P1363866107)

And inductor lowers this into the following:

        buf7 = torch.ops._c10d_functional.all_gather_into_tensor.default(buf6, 4, '5')
        assert_size_stride(buf7, (4, 2048, 4096), (8388608, 4096, 1))
        buf8 = empty_strided_cuda((4, 2048, 4096), (8388608, 4096, 1), torch.float8_e5m2)
        # Source Nodes: [], Original ATen: [aten.view]
        triton_poi_fused_view_3.run(buf7, buf8, 33554432, grid=grid(33554432), stream=stream0)
        del buf7
        buf9 = torch.ops._c10d_functional.wait_tensor.default(buf8)
# triton_poi_fused_view_3
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 33554432
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float8e5, bitcast=True)
    tl.store(out_ptr0 + (x0), tmp1, None)

(full inductor output: P1363866330)

The problem is that:

(1) Normally, AsyncCollectiveTensor tries to ensure that every collective op that is issued is immediately followed by a wait_tensor().

(2) This Float8 code customizes the collective call here to add an extra view op after the allgather. That view op gets inserted between the collective and the wait_tensor().

(3) Normally this is an OK thing to do: view ops don't read their input's data, only its metadata, so it is acceptable to view the result of a collective before syncing and then sync afterwards (before performing any compute) - see the small sketch after this list.

(4) However, inductor lowers the view op (view.dtype) into an out-of-place kernel (the Triton bitcast kernel above, tmp0.to(tl.float8e5, bitcast=True)).

In the generated code, inductor dels the allgather output (buf7) once the view kernel has consumed it, and issues wait_tensor() on the out-of-place copy (buf8) rather than on the collective's own output. I am not entirely sure how this by itself results in a memory leak (in theory buf7 is a bound temporary that should go out of scope when the generated code finishes running) - but running the bitcast kernel on memory that the allgather hasn't finished writing yet seems equally bad.
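
For reference, a minimal eager-mode sketch (not from the repro) of the view-vs-copy distinction: a dtype view between same-width dtypes is a pure metadata change that shares storage, whereas the generated triton_poi_fused_view_3 kernel above materializes the bitcast into a separate buffer (buf8), which is where the aliasing gets lost.

import torch

# hedged sketch: requires a PyTorch build with the float8 dtypes
u8 = torch.zeros(4, 2048, 4096, dtype=torch.uint8)   # stand-in for the allgather output
f8 = u8.view(torch.float8_e5m2)                      # aten.view.dtype: metadata-only reinterpret
assert f8.data_ptr() == u8.data_ptr()                # same storage, nothing is read or copied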

After looking at the code with @drisspg, we realized that if you issue a collective but never call wait_tensor() on its result, c10d is obligated to keep that buffer alive - i.e. to leak it.
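
Spelled out with the same ops that appear in the graph above (a hedged sketch, not the repro: the group_size=4 / group_name='5' arguments are copied from the traced graph and presuppose that process group already exists, on a CUDA build):

import torch

# every issued collective must eventually be waited on so c10d can drop its reference
inp = torch.empty(1, 2048, 4096, dtype=torch.uint8, device="cuda")
out = torch.ops._c10d_functional.all_gather_into_tensor.default(inp, 4, '5')
out = torch.ops._c10d_functional.wait_tensor.default(out)   # releases c10d's bookkeeping for this work

# the generated code above never does this for buf7: it waits on the out-of-place
# copy buf8 instead, so c10d has to keep the collective's output alive -> leak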

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire

vkuzo commented May 15, 2024

bdhirsh commented May 16, 2024

pytorch-labs/float8_experimental#262 is a hotfix for float8, although from talking to @eellison it sounds like we can potentially fix this in inductor by having inductor not convert views into non-views when we know they came directly from a collective.

facebook-github-bot pushed a commit to pytorch-labs/float8_experimental that referenced this issue May 16, 2024
Summary:
I'm going to write a more detailed post internally to explain this memory leak

Tracking issue for a better fix in inductor: pytorch/pytorch#126338

Pull Request resolved: #262

Reviewed By: drisspg

Differential Revision: D57464230

Pulled By: bdhirsh

fbshipit-source-id: 134c50e95045c43f95b5aec4dd3df496ff3fb9a3
eellison commented

We should fix this ASAP, but there is a hotfix already landed. Arguably this should still be high-pri; will look soon.

drisspg commented May 17, 2024

I have some confusion regarding the necessity of the triton_poi_fused_view_3 kernel. The code responsible for generating this kernel can be found here.

This op, torch.ops.aten.view.dtype(all_gather_into_tensor_2, torch.float8_e5m2), is a view, and it should be a no-op since it merely converts between a uint8 and an fp8 dtype, both of which have the same bit width. This code was used to work around the fact that nccl doesn't currently support comms for the fp8 dtypes (or at least our bindings to nccl in pytorch don't). Perhaps we need to allow fp8 dtypes in all_gather and do the uint trick down there?
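
Schematically, the uint8 round trip being described looks like this (a hedged sketch reusing the raw ops and group arguments from the graph above, not the actual float8_experimental code). Note that waiting on the collective's own output before reinterpreting it back to fp8 avoids the pattern that trips up inductor:

import torch

# hedged sketch of the workaround: nccl has no fp8 type, so gather the raw bytes as uint8
fp8_shard = torch.empty(1, 2048, 4096, dtype=torch.float8_e5m2, device="cuda")
as_u8 = fp8_shard.view(torch.uint8)                             # free reinterpret, same bit width
out = torch.ops._c10d_functional.all_gather_into_tensor.default(as_u8, 4, '5')
out = torch.ops._c10d_functional.wait_tensor.default(out)       # wait on the collective's own output first
gathered_fp8 = out.view(torch.float8_e5m2)                      # then reinterpret back to fp8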

@xmfan added the triaged label May 17, 2024
pytorchmergebot pushed a commit that referenced this issue May 18, 2024
# Summary
Different take on this one:
#126338

We should probably not allow this mapping for 'compute' ops e.g. reductions

### Corresponding fp8 PR
pytorch-labs/float8_experimental#263

Pull Request resolved: #126556
Approved by: https://github.com/wanchaol
ZelboK pushed a commit to ZelboK/pytorch that referenced this issue May 19, 2024