Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pallas] align interpreter load/store with masked behaviour #21298

Merged

Conversation

oliverdutton
Copy link
Contributor

@oliverdutton oliverdutton commented May 18, 2024

(and adds stride support)

Fixes #21143

Implements a jittable masked gather/scatter where for load/store/swap any masked indexing does not occur.

For load it sets any masked indexing to index to the first element in the array instead.

For swap(/store) it also sets masked indexing to the first element (and then deals with special rules to make sure the first element is dealt with correctly)

The currently used dynamic_slices are replaced with explicit index materialisation and gathers/scatters.
The advantage of doing it this way is that you can combine it with checkify(f, errors=checkify.index_checks) in interpreter mode to check for any unmasked OOB indexing which is (I think, and believe should be) undefined behaviour.

[apologies this is a reopening of a previous request I'd done badly having not checked contributing.md]

@oliverdutton
Copy link
Contributor Author

oliverdutton commented May 24, 2024

This corrects similar issues to #21180 @justinjfu, though relating to indexing into MemRefs rather than non evenly-divisible block shapes for chunking arrays

@justinjfu justinjfu self-assigned this May 29, 2024
Copy link
Collaborator

@justinjfu justinjfu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! Really appreciate the fixes.

jax/_src/pallas/primitives.py Outdated Show resolved Hide resolved
indices = tuple(jnp.meshgrid(*indices, indexing='ij'))
if mask is not None:
# masked loads set to index first element in array
indices = tuple(jnp.where(mask, indexs, 0) for indexs in indices)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to set masked indices to the first element? Wouldn't these values be masked away anyways on line 328?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's true these values will be masked away on 328. The masked indices being set to the first element is to match the described behaviour of masked loads more closely, the documentation says

Triton avoids the read from/write to memory if it’s masked

as part of

Masking is important when doing out-of-bounds loads/stores. The operational semantics of masking can be compiler-determined (if we understand the documentation properly, Triton avoids the read from/write to memory if it’s masked).
(https://jax.readthedocs.io/en/latest/pallas/design.html)

Those masked indices might be OOB values and should not be indexed at. By diverting them to the first index this is the best jittable approximation I could make.

The link to checkify is that this change makes it possible to distinguish between an unmasked OOB indexing that will actually happen on device and a a masked OOB indexing that will. So the masked load is compatible with debugging in JAX for NaNs by any variant of https://jax.readthedocs.io/en/latest/debugging/flags.html

e.g. for Ref of shape (4,) and indices of [1,2,3,7], mask of [True,True,True,False] is a completely valid load, however running checkify check_nan's on the interpretted code would throw an OOB error without this fix as the indexing at 7 would 'physically' occur, the divert to 0 avoids this.

Copy link
Collaborator

@justinjfu justinjfu May 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see - OOB indexes that are masked being legal does make the problem more complicated.

The cleaner way to handle this for checkify is probably to add a custom checkify rule for the load and swap primitives (for example, see https://github.com/google/jax/blob/main/jax/_src/checkify.py#L608). The interpret mode discharge rule is primarily responsible for returning the correct values.

What I propose is this change:

  1. In the discharge rule for load and swap, simply pad the inputs to the maximum shape with jax.numpy.pad at the beginning of the call, then you can leave the rest of the logic unchanged.
  2. Move the explicit indexing logic to a custom checkify rule for load & swap which will check if there are any unmasked indices that are OOB. If they exist, return the OOB error. I believe checkify only checks for NaNs for a subset of arithmetic ops, so you wouldn't need to do NaN checking here.

I think this would keep the discharge implementation simple and fast but also be compatible with checkify & jit. Does that sound reasonable or are there other use cases which wouldn't be covered?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correctness

  1. In the discharge rule for load and swap, simply pad the inputs to the maximum shape with jax.numpy.pad at the beginning of the call, then you can leave the rest of the logic unchanged.

This equates to padding by the slice size in each dimension. Agreed this will give correct output without performance hit. Will implement.

Debugging interaction with masked OOBs

  1. Move the explicit indexing logic to a custom checkify rule for load & swap which will check if there are any unmasked indices that are OOB. If they exist, return the OOB error. I believe checkify only checks for NaNs for a subset of arithmetic ops, so you wouldn't need to do NaN checking here.

I think it makes sense for that to be dealt with by the NaN/OOB checkers. Will look into how to do that and push that into a separate PR to this one.

This PR will be just to fix #21143 and do 1.

Will put together on Monday evening

tests/pallas/pallas_test.py Outdated Show resolved Hide resolved
tests/pallas/pallas_test.py Outdated Show resolved Hide resolved
jax/_src/pallas/primitives.py Outdated Show resolved Hide resolved
@oliverdutton oliverdutton force-pushed the pallas_interpreter_indexing_fix branch from 341e249 to d66b107 Compare June 2, 2024 21:57
@oliverdutton oliverdutton force-pushed the pallas_interpreter_indexing_fix branch from d66b107 to afc52fe Compare June 2, 2024 22:06
@oliverdutton
Copy link
Contributor Author

oliverdutton commented Jun 3, 2024

This PR is now solely aimed at fixing dynamic slice store/load operations in the interpreter identified in #21143. In the interpreter it pads the arrays with uninitialised values so dynamic slices are never pushed back into the 'true' array.

Tests for both load and store are added. I've confirmed they error pre this commit and are fixed by this change (as is the original bug report).

Separate PRs may be made to solve separate issues:

  • Special rules for OOB and NaN checking in checkify for masked store/load
  • Slice strides not equal to 1

Copy link
Collaborator

@justinjfu justinjfu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Looks great, just some minor comments.

@oliverdutton oliverdutton force-pushed the pallas_interpreter_indexing_fix branch 2 times, most recently from da0b2d2 to 9218a3c Compare June 4, 2024 22:08
@oliverdutton oliverdutton force-pushed the pallas_interpreter_indexing_fix branch from 9218a3c to fa29f18 Compare June 4, 2024 22:20
@oliverdutton
Copy link
Contributor Author

Comments implemented and now should be ready

Copy link
Collaborator

@justinjfu justinjfu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved.

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jun 4, 2024
@copybara-service copybara-service bot merged commit 7d913f7 into google:main Jun 7, 2024
12 of 13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[pallas] Interpreter mismatch for masked OOB indexing
3 participants