
fix: scipy rel_entr edge case #21193

Merged (1 commit, May 28, 2024)
Conversation

sh0416
Contributor

@sh0416 sh0416 commented May 12, 2024

#21192

The value of rel_entr in the edge case (p == 0, q >= 0) should be 0, but the implementation sets it to the q distribution.
https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.special.rel_entr.html

I've fixed it, but it needs to be checked.

I'm new to contributing to this repo, so I'd be pleased if someone could check that this is right.
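For context, SciPy documents rel_entr(p, q) as p*log(p/q) when p > 0 and q > 0, as 0 when p == 0 and q >= 0, and as inf otherwise. A minimal NumPy sketch of those semantics (rel_entr_ref is a hypothetical reference helper, not JAX's implementation):

```python
import numpy as np

def rel_entr_ref(p, q):
    """Reference semantics of scipy.special.rel_entr (hypothetical helper)."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    with np.errstate(divide='ignore', invalid='ignore'):
        # Well-defined region: p > 0 and q > 0.
        val = np.where((p > 0) & (q > 0), p * np.log(p / q), np.inf)
    # The edge case this PR fixes: the result is 0 (not q) when p == 0 and q >= 0.
    return np.where((p == 0) & (q >= 0), 0.0, val)

print(rel_entr_ref(0.0, 2.0))  # 0.0 in the edge case
```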


google-cla bot commented May 12, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@jakevdp
Collaborator

jakevdp commented May 12, 2024

Thanks for the PR! This looks like the right fix, but we should also add some test coverage. Could you also change these two tests to use rand_some_zero rather than rand_positive?

op_record("kl_div", 2, float_dtypes, jtu.rand_positive, True),
op_record("rel_entr", 2, float_dtypes, jtu.rand_positive, True),

@sh0416
Contributor Author

sh0416 commented May 13, 2024

I've checked your suggestion. rand_some_zero also contains negative values, but the two functions should work with negative values too, so it is okay to change them as you suggested. Thank you for guiding me. Is there anything else I should add or change in the codebase?

@jakevdp
Collaborator

jakevdp commented May 13, 2024

Hi - it looks like there are problems with the autodiff tests when we include zero. I guess I should have anticipated that: autodiff is checked against finite differences, which break down around domain edges. Let's try something like this instead:

    # test autodiff for kl_div & rel_entr with positive inputs:
    op_record("kl_div", 2, float_dtypes, jtu.rand_positive, True),
    op_record("rel_entr", 2, float_dtypes, jtu.rand_positive, True),
    # test without autodiff for inputs including zero:
    op_record("kl_div", 2, float_dtypes, jtu.rand_some_zero, False),
    op_record("rel_entr", 2, float_dtypes, jtu.rand_some_zero, False),

This gives us autodiff test coverage in the well-defined domain, and output test coverage in the corner cases.
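The finite-difference breakdown at the domain edge can be seen in a standalone NumPy sketch (not JAX's test code): a central difference at p = 0 must evaluate the function at p = -eps, which lies outside the domain.

```python
import numpy as np

def f(p):
    # rel_entr(p, 1.0) = p*log(p), extended with f(0) = 0
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(p == 0, 0.0, p * np.log(p))

eps = 1e-4
# Central difference at p = 0: f(-eps) involves log of a negative number -> nan.
central = (f(0.0 + eps) - f(0.0 - eps)) / (2 * eps)
print(np.isnan(central))  # True
```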

@jakevdp jakevdp self-requested a review May 13, 2024 16:08
@sh0416
Contributor Author

sh0416 commented May 14, 2024

Thanks for the advice. I am doing it as suggested.

@sh0416
Contributor Author

sh0416 commented May 14, 2024

The test names are duplicated; I will handle it immediately.

@sh0416
Contributor Author

sh0416 commented May 14, 2024

I've checked that the new test cases pass, except for LaxScipySpcialFunctionsTest.testScipySpecialFun_ndtri_1x4_float32_ndtri, which is irrelevant to this PR. It might be resolved in the GitHub CI, I hope.
@jakevdp Could you rerun the CI?

  ))
  for rec in JAX_SPECIAL_FUNCTION_RECORDS
))
@jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard')  # This test explicitly exercises dtype promotion
def testScipySpecialFun(self, op, rng_factory, shapes, dtypes,
-                       test_autodiff, nondiff_argnums):
+                       test_autodiff, nondiff_argnums, test_name):
Collaborator

Can you say more about why this change is necessary? It doesn't look like test_name is used in the function.

Contributor Author

In my reasoning, the test_name variable is needed to generate the actual test name in pretty_special_fun_name (to distinguish the names of the new test cases). However, pytest.parameterize passes all values through to the decorated function, so I added them even though they are not used in the function body. I couldn't figure out how to do this without adding the new function argument, but if there is a way, I will change my implementation.

Collaborator

@jakevdp jakevdp May 14, 2024

Ah, I see - it looks like test_name is no longer used after we refactored these tests last year. I think this is the wrong approach.

Instead of adding more test cases that will run into duplicate-name issues, let's leave the previous test cases alone and add a new manual test for each at the bottom of the file that checks the behavior at zero. What do you think?

Collaborator

Actually, another thought if you're willing to look into it: I think the semantics of jax.scipy.special.xlogy match what we'd need for rel_entr, and there's a custom JVP rule to correctly handle gradients near zero. I'd bet if we re-implemented rel_entr in terms of xlogy, we could make the existing test work with the expanded domain. Is that something you want to explore?
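For reference, xlogy(x, y) computes x*log(y) with the convention that the result is 0 whenever x == 0, regardless of y, which is what motivates the suggestion above. A minimal NumPy sketch of that convention (xlogy_ref is a hypothetical stand-in, not jax.scipy.special.xlogy itself):

```python
import numpy as np

def xlogy_ref(x, y):
    """x*log(y), with the convention xlogy(0, y) == 0 for any y (hypothetical sketch)."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(x == 0, 0.0, x * np.log(y))

print(xlogy_ref(0.0, 0.0))   # 0.0: the x == 0 convention
print(xlogy_ref(2.0, np.e))  # approximately 2.0
```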

Collaborator

(Alternatively, you could just limit your change to the fix: replacing q with zero and I'll follow-up on the autodiff & testing questions)

Contributor Author

It seems that xlogy is tested only on the domain of positive values. I tried testing xlogy with jtu.rand_some_zero, but it could not pass the test. I think the JVP rule shows some instability. Changing the division operator to lax.div gives more stable results, but two test cases still fail: LaxScipySpcialFunctionsTest.testScipySpecialFun_xlogy_3x4_3x4_float32_float32 and LaxScipySpcialFunctionsTest.testScipySpecialFun_xlogy_3x1_2x1x4_float32_float32.

I have already implemented rel_entr in terms of xlogy. Could you see why the JVP rule of xlogy does not work on these two test cases? I couldn't figure it out.

Contributor Author

E           AssertionError:
E           Not equal to tolerance rtol=0.3, atol=0.003
E           JVP tangent
E           x and y -inf location mismatch:
E            x: array([[0.],
E                  [0.],
E                  [0.]], dtype=float32)
E            y: array([[  0.],
E                  [-inf],
E                  [  0.]], dtype=float32)

For LaxScipySpcialFunctionsTest.testScipySpecialFun_xlogy_3x1_2x1x4_float32_float32.

E           AssertionError:
E           Not equal to tolerance rtol=1.2, atol=0.012
E           JVP tangent
E           x and y nan location mismatch:
E            x: array([[nan, nan, nan, nan],
E                  [nan, nan, nan, nan],
E                  [nan, nan, nan, nan]], dtype=float32)
E            y: array([[-inf, -inf, -inf, -inf],
E                  [  0.,   0.,   0.,   0.],
E                  [  0.,   0.,   0.,   0.]], dtype=float32)

For LaxScipySpcialFunctionsTest.testScipySpecialFun_xlogy_3x4_3x4_float32_float32.

Collaborator

@jakevdp jakevdp May 15, 2024

OK, then I was probably wrong about xlogy being applicable here. Thanks for checking. I suspect we'll need a custom JVP rule for rel_entr and kl_div in order to have correct gradient behavior at zero.

Again, I know that's probably more than you bargained for when embarking on this: let me know if you want to limit this PR to your original fix, which looks good to me.

Contributor Author

@sh0416 sh0416 May 16, 2024

Let's sort these things out.

First, I checked the behavior of scipy.special.xlogy.

>>> from scipy.special import xlogy
>>> xlogy([[1.1],[0],[-1.1]], [[1.1,0,-1.1]])
array([[ 0.1048412,       -inf,        nan],
       [ 0.       ,  0.       ,  0.       ],
       [-0.1048412,        inf,        nan]])

It seems that the output is well defined when y >= 0 or x == 0 (the non-nan region).

Second, I considered whether the gradient (e.g., the JVP tangent) can be computed through a numerical approach, (f(x+eps) - f(x-eps)) / (2*eps).

I've checked the code that tests the difference between the custom JVP and the numerical difference, and found that it does not consider nan values (only inf), which is natural since the gradient cannot be defined through nan values.

def numerical_jvp(f, primals, tangents, eps=EPS):
  delta = scalar_mul(tangents, eps)
  f_pos = f(*add(primals, delta))
  f_neg = f(*sub(primals, delta))
  return scalar_mul(safe_sub(f_pos, f_neg), 0.5 / eps)

safe_sub = partial(tree_map,
                   lambda x, y: _safe_subtract(x, y, dtype=_dtype(x)))

def _safe_subtract(x, y, *, dtype):
  """Subtraction that with `inf - inf == 0` semantics."""
  with np.errstate(invalid='ignore'):
    return np.where(np.equal(x, y), np.array(0, dtype),
                    np.subtract(x, y, dtype=dtype))

(Also, it is really intriguing to check the behavior of inf values.)

>>> import numpy as np
>>> np.inf - np.inf
nan
>>> np.inf - (-np.inf)
inf
>>> -np.inf - (np.inf)
-inf
>>> -np.inf - (-np.inf)
nan

In this case, when we compute the JVP tangent at y = 0, the numerical approach is not valid because f(x - eps) would be nan.

Third, I thought about a general approach to computing the numerical gradient at boundary cases. One possible solution is to use the left or right one-sided gradient, (f(x) - f(x-eps)) / eps or (f(x+eps) - f(x)) / eps, where the terms in the numerator would be valid in this case.

However, I am wary of changing check_grad in public_test_util.py as it might break lots of existing test cases, so I have to think about other approaches.
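The one-sided scheme described above can be illustrated with a standalone NumPy sketch (not JAX's check_grads): both evaluations stay inside the domain, so the result is finite for a fixed eps.

```python
import numpy as np

def f(p):
    # rel_entr(p, 1.0) = p*log(p), with f(0) = 0 at the boundary
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(p == 0, 0.0, p * np.log(p))

eps = 1e-6
# Right one-sided difference at p = 0: no evaluation outside the domain.
right = (f(0.0 + eps) - f(0.0)) / eps
print(np.isfinite(right))  # True; the value is approximately log(eps)
```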

Collaborator

@jakevdp jakevdp May 16, 2024

Right, I think I was wrong that xlogy has relevant semantics. It doesn't and you should ignore that suggestion.

And we should definitely not change check_grad – we should write a custom JVP rule similar to xlogy for rel_entr and kl_div to return the correct gradients for zero inputs, so that it passes the existing check_grads test.

You can tackle that if you wish, but I'm happy to do it as a follow-up as well, because it's fairly subtle to get right. Your initial change – replacing q with zero – fixes a bug, and I'd be happy to merge that change.
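A custom JVP rule along those lines might look roughly like the following. This is a hypothetical sketch: the names rel_entr and rel_entr_jvp here are illustrative, not JAX's actual implementation, and the zero/inf cases in the tangent are exactly the subtle part.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def rel_entr(p, q):
    # rel_entr(p, q) = p*log(p/q) for p > 0, q > 0; 0 for p == 0, q >= 0; inf otherwise.
    safe = (p > 0) & (q > 0)
    ratio = jnp.where(safe, p, 1.0) / jnp.where(safe, q, 1.0)
    val = jnp.where(safe, p * jnp.log(ratio), jnp.inf)
    # The edge case fixed in this PR: return 0 (not q) when p == 0, q >= 0.
    return jnp.where((p == 0) & (q >= 0), 0.0, val)

@rel_entr.defjvp
def rel_entr_jvp(primals, tangents):
    p, q = primals
    dp, dq = tangents
    # d/dp [p*log(p/q)] = log(p/q) + 1,  d/dq [p*log(p/q)] = -p/q
    tangent = (jnp.log(p / q) + 1.0) * dp - (p / q) * dq
    return rel_entr(p, q), tangent
```

With a rule of this shape, jax.grad(rel_entr, argnums=1)(2.0, 1.0) gives -p/q = -2.0 in the interior of the domain; the behavior exactly at p == 0 is what still needs careful handling.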

@jakevdp
Collaborator

jakevdp commented May 28, 2024

Hi - are you still interested in working on this? This is a bug that we really need to get fixed.

@sh0416
Contributor Author

sh0416 commented May 28, 2024

I will sort it out within an hour. Sorry for the late reply.

@sh0416
Contributor Author

sh0416 commented May 28, 2024

@jakevdp I'm done. I added a test case for the boundary value and checked that all test cases pass. Thank you for handling this PR with me.

jax/_src/scipy/special.py (outdated review thread, resolved)
@sh0416 sh0416 requested a review from jakevdp May 28, 2024 13:07
@jakevdp
Collaborator

jakevdp commented May 28, 2024

Thanks, looks good! Last thing: can you please squash your changes into a single commit? (see https://jax.readthedocs.io/en/latest/contributing.html#single-change-commits-and-pull-requests) This helps us keep the commit history clean and linear. Thanks!

@sh0416
Contributor Author

sh0416 commented May 28, 2024

> Thanks, looks good! Last thing: can you please squash your changes into a single commit? (see https://jax.readthedocs.io/en/latest/contributing.html#single-change-commits-and-pull-requests) This helps us keep the commit history clean and linear. Thanks!

I've done this. It seems clean. :) @jakevdp

Collaborator

@jakevdp jakevdp left a comment

Looks great, thanks!

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels May 28, 2024
@copybara-service copybara-service bot merged commit 82ad1da into google:main May 28, 2024
12 of 13 checks passed