Allowing topology writer to pass hints to accelerator backends #126393

Open
kbadz opened this issue May 16, 2024 · 0 comments · May be fixed by #121639
kbadz commented May 16, 2024

🚀 The feature, motivation and pitch

High-level idea:
We would sometimes like to let the end user pass additional context to our backend, ideally in a way that lets them attach this information to specific operations of their choice. Later in the flow, the backend can use this information to produce a better optimization or execution scheme, one that more closely matches what the user expected. This enables more versatile programmability. Examples include, but are not limited to: choosing streams, deciding the order of execution, hinting slicing schemes, and bundling sets of operations together. In the case of scheduling there are several options, such as BFS, DFS, or an even more customized execution order (strict scheduling) provided by the user. Support for these "hints" can be optional per backend, and a backend that does support them may still decide to ignore them if they cannot be fulfilled, or if fulfilling them would break other mechanisms functionally or performance-wise.

Specific use case:
For the purpose of this ticket, let's focus on a very specific use case that we currently have in our Gaudi stack. For the flash attention operation, we can execute its batch*num_heads parallel slices in either a DFS or a BFS manner. For many operations it is more performant to go with BFS, because it exposes more parallel paths that can be executed simultaneously at a given time, making it easier to spread the work across the execution engines. But in the case of flash attention, where the amount of memory needed can be a significant issue, up to the point where the topology cannot fit on the accelerator, it may be better to go with DFS: that way we only need to allocate memory for a single parallel slice at a time, and we can reuse all the intermediate tensors for the next slices. We sacrifice some raw performance, but reduce the memory needed to compute this operation almost linearly with the number of slices. Such a trade-off can be critical for accelerators with smaller amounts of memory.
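To make the trade-off concrete, here is a back-of-the-envelope sketch; all numbers are made up for illustration and do not come from the Gaudi stack.

```python
# Hypothetical numbers, purely to illustrate the BFS vs. DFS memory trade-off.
batch, num_heads = 8, 16                     # 128 parallel attention slices
per_slice_workspace_mb = 64                  # intermediates needed by one slice

bfs_peak_mb = batch * num_heads * per_slice_workspace_mb  # all slices alive at once
dfs_peak_mb = per_slice_workspace_mb                      # one slice at a time, buffers reused

print(bfs_peak_mb, dfs_peak_mb)              # 8192 MB vs 64 MB, traded for less parallelism
```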

Limitation to solve:
Unfortunately, backend compilers usually cannot be sure what the best approach is for computing an arbitrary subgraph, because they do not know what is important to the user or what the expected constraints of the end-to-end workload are. This is why it is sometimes good to let the user provide some kind of "hints". Such "hints" would also benefit other backends that would like to give the user this kind of choice.

Design considerations:
The first consideration is: from the end-user's point of view, to achieve this kind of programmability, they could either:

  1. For every chunk of operations the user would like to hint, wrap its inputs and outputs with custom OPs carrying the hints, and let backend graph analysis infer the hinted regions from data flow, somewhat similarly to what the partitioner does.
  2. Reuse the Higher Order OP (HOO) mechanism to wrap groups of operations, possibly nested, with specific hints, and provide the hints via the meta field of these ops.

Our recommendation is the second option, because the first one would require either the backend or PyTorch to implement tricky graph analysis to automatically wrap the nodes, while an HOO does that by design during tracing. The user experience is also much better with an HOO, because one does not have to track non-obvious data flow to correctly wrap more complicated cases with nested hinting. A rough sketch of both styles is given below.
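To make the two styles concrete, here is what the user-facing code could look like. All op names below (hint_region_begin, hint_region_end, hinted_context) are hypothetical and implemented as trivial eager stand-ins so the snippet runs; real versions would be registered custom ops (option 1) or a higher-order op (option 2).

```python
import torch

# Eager stand-ins so the sketch runs; these carry no real hint machinery.
def hint_region_begin(*tensors, hints=None):
    return tensors

def hint_region_end(t):
    return t

def hinted_context(fn, args, hints=None):
    return fn(*args)

def attention_block(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

# Option 1: boundary markers; the backend must recover the hinted region
# from the data flow between the begin/end ops.
def model_option1(q, k, v):
    q, k, v = hint_region_begin(q, k, v, hints={"schedule": "dfs"})
    out = attention_block(q, k, v)
    return hint_region_end(out)

# Option 2: the higher-order op wraps the region directly, so the traced
# graph carries the region boundary (and its hints) by construction.
def model_option2(q, k, v):
    return hinted_context(attention_block, (q, k, v), hints={"schedule": "dfs"})
```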

The second consideration is: if we go with the HOO option, we need to choose which flavor to use:

  1. A fully generic hinted_context HOO.
    pros:
    everyone can use it for everything
    cons:
    hints could easily be lost in graph transformations, since such a generic op may carry many different kinds of hints

  2. A set of dedicated ops; for the presented use case this would be, for example, a graph_schedule HOO.
    pros:
    clear purpose of each specific OP
    easier to track such an OP through transformations
    still usable by everyone
    cons:
    a new OP must be created for every use case, which could prove tedious and hard to maintain at some point
    there is still some risk of losing the hints along the way

  3. Accelerator-specific (e.g. Gaudi-only) custom HOOs that backends can register at runtime.
    pros:
    every backend needs to care only about its own custom HOOs
    cons:
    the end user needs to track and use different OPs depending on which accelerator they target
    migrating workloads between accelerators would become harder
    no such custom HOO registration mechanism exists today

Our recommendation is to go with the first option: it is the easiest to implement, maintain, and use. The call shapes of the first two flavors are sketched below.
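For comparison, the call shapes of the first two flavors might look roughly like this; both op names are hypothetical, and the stand-ins below just run the wrapped function eagerly.

```python
import torch

def hinted_context(fn, args, hints=None):    # flavor 1: fully generic, arbitrary hints dict
    return fn(*args)

def graph_schedule(fn, args, order="bfs"):   # flavor 2: dedicated op, one hint kind
    return fn(*args)

q = k = v = torch.randn(2, 4, 128, 64)       # (batch, heads, seq, head_dim)
sdpa = torch.nn.functional.scaled_dot_product_attention

out1 = hinted_context(sdpa, (q, k, v), hints={"schedule": "dfs"})
out2 = graph_schedule(sdpa, (q, k, v), order="dfs")
```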

That said, this choice has one caveat: the hints can be lost along the way. The first risk is losing them in the torch.compile implementation in the framework itself. This can be mitigated by allowing these OPs only inside custom autograd.Function implementations, which are traced only at the AOT stage; that is quite far down the stack, and there are no other transformations (at least none I am aware of) that would pose a significant risk from that point on. A minimal sketch of this pattern follows.
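The sketch below assumes the hypothetical hinted_context stand-in from the earlier snippets: the hinted region is emitted only from inside a custom autograd.Function, so the hint first appears in the graph produced at the AOT stage.

```python
import torch

def hinted_context(fn, args, hints=None):
    # Eager stand-in; a real HOO would stamp `hints` onto the traced call node.
    return fn(*args)

class HintedAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        # The hinted region is only ever emitted from inside this Function.
        return hinted_context(
            torch.nn.functional.scaled_dot_product_attention,
            (q, k, v),
            hints={"schedule": "dfs"},
        )

    @staticmethod
    def backward(ctx, grad_out):
        # Backward is omitted from this sketch; a real implementation would
        # return gradients for q, k and v here.
        raise NotImplementedError
```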

The second risk is losing the hints in the accelerator backend. Since this is an optional feature, every backend is free to opt in by implementing graph transformations that are safe with respect to the hints it declares it supports. There is no risk at all for a backend that does not want to support hints; it is free to ignore them.

The third consideration is how to expose the "hint" information to the accelerator backend:

  1. Using the nodes' metadata field.
    pros:
    easy to implement
    easy to use if wanted
    easy to ignore if not wanted
    cons:
    easy to lose

  2. Injecting additional "hint" nodes into the graph, tied closely to the hinted nodes.
    pros:
    harder to lose
    cons:
    much harder to implement
    visible even to backends that do not want this feature at all

Our recommendation is to go with the metadata option, as it is the easiest and can be completely transparent to backends that are not interested in the feature or not ready for it. Backends that do use the feature would need to pay close attention to node metadata; a sketch of such a backend is shown below.
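As an illustration, a hint-aware torch.compile backend could look roughly like this. The "hints" meta key and the "schedule" hint name are assumptions made for this example; the actual PR may use different names.

```python
import torch

def hint_aware_backend(gm: torch.fx.GraphModule, example_inputs):
    for node in gm.graph.nodes:
        hints = node.meta.get("hints")       # simply absent on un-hinted nodes
        if hints and hints.get("schedule") == "dfs":
            # A real backend would steer its scheduler toward depth-first
            # execution of the hinted region; this sketch only reports it.
            print(f"{node.name}: user requested depth-first scheduling")
    # Backends not interested in hints never read node.meta["hints"] at all.
    return gm.forward

# Usage: compiled = torch.compile(model, backend=hint_aware_backend)
```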

The feature PR that goes along with our recommendations: #121639

Alternatives

No response

Additional context

No response

cc @bdhirsh @ezyang @msaroufim @anijain2305 @chauhang @zou3519 @ydwu4

@mikaylagawarecki mikaylagawarecki added module: backend non-standard backend support oncall: pt2 module: higher order operators torch.cond and similar labels May 20, 2024
@xmfan xmfan added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label May 20, 2024
@zou3519 zou3519 added the module: pt2-dispatcher PT2 dispatcher-related issues (e.g., aotdispatch, functionalization, faketensor, custom-op, label May 20, 2024