fix: scipy rel_entr edge case #21193
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.

Thanks for the PR! This looks like the right fix, but we should also add some test coverage. Could you also change these two tests to use `jtu.rand_some_zero`? (see `jax/tests/lax_scipy_special_functions_test.py`, lines 145 to 148 in b4f2145)
I've checked your suggestion. `rand_some_zero` contains negative values, but the two functions should work with negative values, so it is fine to change them as you suggested. Thank you for guiding me. Is there anything else I should add or change in the codebase?
Hi - so it looks like there are problems with the autodiff tests when we include zero. I guess I should have anticipated that: autodiff is tested against finite differences, which will break down around domain edges. Let's try something like this instead:

```python
# test autodiff for kl_div & rel_entr with positive inputs:
op_record("kl_div", 2, float_dtypes, jtu.rand_positive, True),
op_record("rel_entr", 2, float_dtypes, jtu.rand_positive, True),
# test without autodiff for inputs including zero:
op_record("kl_div", 2, float_dtypes, jtu.rand_some_zero, False),
op_record("rel_entr", 2, float_dtypes, jtu.rand_some_zero, False),
```

This gives us autodiff test coverage in the well-defined domain, and output test coverage in the corner cases.
Thanks for the advice. I'll do it as suggested.
The test name is duplicated; I will handle it immediately.
I've checked that the new test cases pass except …
```diff
   ))
   for rec in JAX_SPECIAL_FUNCTION_RECORDS
   ))
   @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
   @jax.numpy_dtype_promotion('standard')  # This test explicitly exercises dtype promotion
   def testScipySpecialFun(self, op, rng_factory, shapes, dtypes,
-                          test_autodiff, nondiff_argnums):
+                          test_autodiff, nondiff_argnums, test_name):
```
Can you say more about why this change is necessary? It doesn't look like `test_name` is used in the function.
My thinking was that the `test_name` variable is needed to generate the actual test name in `pretty_special_fun_name` (to distinguish the name of the new test case). However, `pytest.parameterize` passes the whole set of values to the decorated function, so I added the argument even though it is not used in the function body. I couldn't figure out how to do this without adding the new function argument, but if there is a way, I will change my implementation.
Ah I see - it looks like `test_name` is not used anymore after we refactored these tests last year. I think this is the wrong approach. Instead of adding more test cases that will run into duplicate name issues, let's leave the previous test cases alone and add a new manual test for each at the bottom of the file that checks the behavior at zero. What do you think?
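For reference, a minimal sketch of the kind of manual boundary test being suggested (the test name and structure here are illustrative, not the actual test added in this PR); it uses `scipy.special.rel_entr` to spell out the zero-input convention that the JAX implementation should match:

```python
import numpy as np
from scipy.special import rel_entr  # reference for the expected convention

def test_rel_entr_at_zero():
    # scipy's convention: rel_entr(0, q) == 0 for any q >= 0,
    # and rel_entr(p, 0) == inf for p > 0.
    p = np.array([0.0, 0.0, 0.5])
    q = np.array([0.0, 0.3, 0.0])
    out = rel_entr(p, q)
    assert out[0] == 0.0
    assert out[1] == 0.0
    assert np.isposinf(out[2])

test_rel_entr_at_zero()
```

The same checks would be written against `jax.scipy.special.rel_entr` in the actual test file.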
Actually, another thought if you're willing to look into it: I think the semantics of `jax.scipy.special.xlogy` match what we'd need for `rel_entr`, and there's a custom JVP rule to correctly handle gradients near zero. I'd bet if we re-implemented `rel_entr` in terms of `xlogy`, we could make the existing test work with the expanded domain. Is that something you want to explore?
(Alternatively, you could just limit your change to the fix: replacing `q` with `zero`, and I'll follow up on the autodiff & testing questions.)
It seems that `xlogy` is tested only in the domain of positive values. I've tried testing `xlogy` with `jtu.rand_some_zero`, but it could not pass the test. I think the JVP rule shows some instability. Changing the division operator to `lax.div` gives more stable results, but two test cases still fail: `LaxScipySpcialFunctionsTest::testScipySpecialFun_xlogy_3x4_3x4_float32_float32` and `LaxScipySpcialFunctionsTest::testScipySpecialFun_xlogy_3x1_2x1x4_float32_float32`.

I have already implemented `rel_entr` with `xlogy`. Could you see why the JVP rule of `xlogy` does not work on these two test cases? I couldn't figure it out.
For `LaxScipySpcialFunctionsTest.testScipySpecialFun_xlogy_3x1_2x1x4_float32_float32`:

```
E  AssertionError:
E  Not equal to tolerance rtol=0.3, atol=0.003
E  JVP tangent
E  x and y -inf location mismatch:
E  x: array([[0.],
E            [0.],
E            [0.]], dtype=float32)
E  y: array([[  0.],
E            [-inf],
E            [  0.]], dtype=float32)
```

For `LaxScipySpcialFunctionsTest.testScipySpecialFun_xlogy_3x4_3x4_float32_float32`:

```
E  AssertionError:
E  Not equal to tolerance rtol=1.2, atol=0.012
E  JVP tangent
E  x and y nan location mismatch:
E  x: array([[nan, nan, nan, nan],
E            [nan, nan, nan, nan],
E            [nan, nan, nan, nan]], dtype=float32)
E  y: array([[-inf, -inf, -inf, -inf],
E            [  0.,   0.,   0.,   0.],
E            [  0.,   0.,   0.,   0.]], dtype=float32)
```
OK, then I was probably wrong about `xlogy` being applicable here. Thanks for checking. I suspect we'll need a custom JVP rule for `rel_entr` and `kl_div` in order to have correct gradient behavior at zero.

Again, I know that's probably more than you bargained for when embarking on this: let me know if you want to limit this PR to your original fix, which looks good to me.
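A rough sketch of what such a custom JVP rule might look like. To be clear, this is illustrative only: the `safe_p`/`safe_q` masking and the choice of a zero tangent at the domain boundary are assumptions for the example, not the rule that eventually landed in JAX:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def rel_entr(p, q):
    # rel_entr(p, q) = p*log(p/q) for p>0, q>0; 0 for p==0, q>=0; inf otherwise
    both_pos = (p > 0) & (q > 0)
    p_zero = (p == 0) & (q >= 0)
    safe_p = jnp.where(both_pos, p, 1.0)  # avoid nan in the untaken branch
    safe_q = jnp.where(both_pos, q, 1.0)
    main = safe_p * jnp.log(safe_p / safe_q)
    return jnp.where(both_pos, main, jnp.where(p_zero, 0.0, jnp.inf))

@rel_entr.defjvp
def _rel_entr_jvp(primals, tangents):
    p, q = primals
    p_dot, q_dot = tangents
    primal_out = rel_entr(p, q)
    both_pos = (p > 0) & (q > 0)
    safe_p = jnp.where(both_pos, p, 1.0)
    safe_q = jnp.where(both_pos, q, 1.0)
    # interior derivatives: d/dp = log(p/q) + 1, d/dq = -p/q;
    # a zero tangent at the boundary is a deliberate (assumed) choice here
    dp = jnp.log(safe_p / safe_q) + 1.0
    dq = -safe_p / safe_q
    tangent_out = jnp.where(both_pos, dp * p_dot + dq * q_dot, 0.0)
    return primal_out, tangent_out
```

The `jnp.where` masking before the `log` and division keeps nans out of both the primal and the tangent computation, which is the usual pattern for making gradients of piecewise functions well behaved in JAX.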
Let's sort these things out.

First, I checked the behavior of `scipy.special.xlogy`:

```python
>>> from scipy.special import xlogy
>>> xlogy([[1.1],[0],[-1.1]], [[1.1,0,-1.1]])
array([[ 0.1048412,      -inf,       nan],
       [ 0.       ,  0.       ,  0.       ],
       [-0.1048412,       inf,       nan]])
```

It seems that the output is well defined when `y>=0` or `x==0` (the non-nan region).
Second, I considered whether the gradient (e.g., the JVP tangent) could be computed through a numerical approach, `(f(x+eps) - f(x-eps)) / (2*eps)`. I've checked the code that tests the difference between the custom JVP and the numerical difference, and found that it does not consider nan values (only inf), which is natural, as the gradient cannot be defined with nan values.
`jax/jax/_src/public_test_util.py`, lines 211 to 215 in 6fe313c:

```python
def numerical_jvp(f, primals, tangents, eps=EPS):
  delta = scalar_mul(tangents, eps)
  f_pos = f(*add(primals, delta))
  f_neg = f(*sub(primals, delta))
  return scalar_mul(safe_sub(f_pos, f_neg), 0.5 / eps)
```

`jax/jax/_src/public_test_util.py`, lines 187 to 188 in 6fe313c:

```python
safe_sub = partial(tree_map,
                   lambda x, y: _safe_subtract(x, y, dtype=_dtype(x)))
```

`jax/jax/_src/public_test_util.py`, lines 169 to 173 in 6fe313c:

```python
def _safe_subtract(x, y, *, dtype):
  """Subtraction that with `inf - inf == 0` semantics."""
  with np.errstate(invalid='ignore'):
    return np.where(np.equal(x, y), np.array(0, dtype),
                    np.subtract(x, y, dtype=dtype))
```
(Also, it is really intriguing to check the behavior of inf values:)

```python
>>> import numpy as np
>>> np.inf - np.inf
nan
>>> np.inf - (-np.inf)
inf
>>> -np.inf - (np.inf)
-inf
>>> -np.inf - (-np.inf)
nan
```
In this case, when we compute the JVP tangent at y=0, the numerical approach is not valid, because f(x-eps) would be nan.

Third, I thought about a general approach to computing the numerical gradient at boundary cases. One possible solution is to use the left or right gradient, `(f(x) - f(x-eps)) / eps` or `(f(x+eps) - f(x)) / eps`, where the terms in the numerator would be valid in this case. However, I am wary of changing `check_grads` in `public_test_util.py`, as it might break lots of existing test cases, so I have to think about other approaches.
Right, I think I was wrong that `xlogy` has relevant semantics. It doesn't, and you should ignore that suggestion.

And we should definitely not change `check_grads` – we should write a custom JVP rule, similar to `xlogy`'s, for `rel_entr` and `kl_div` to return the correct gradients for zero inputs, so that it passes the existing `check_grads` test.

You can tackle that if you wish, but I'm happy to do it as a followup as well, because it's fairly subtle to get right. Your initial change – replacing `q` with `zero` – fixes a bug, and I'd be happy to merge that change.
Hi - are you still interested in working on this? This is a bug that we really need to get fixed.
I will sort it out within an hour. Sorry for the late reply.
@jakevdp I'm done. I added a test case for the boundary value. I checked that all the test cases pass. Thank you for handling this PR with me.
Thanks, looks good! Last thing: can you please squash your changes into a single commit? (see https://jax.readthedocs.io/en/latest/contributing.html#single-change-commits-and-pull-requests) This helps us keep the commit history clean and linear. Thanks!
I've done this. It seems clean. :) @jakevdp
Looks great, thanks!
#21192

The value for the edge case in `rel_entr` is 0, but the implementation sets it to the `q` distribution.

https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.special.rel_entr.html

I've fixed it, but it needs to be checked. I am new to contributing to this repo, so I would be pleased if someone could check that this is right.
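To illustrate the reported bug and the shape of the fix, here is a small numpy sketch (not the actual JAX source; the function names and structure are made up for the example) contrasting the two behaviors for the `p == 0` edge case:

```python
import numpy as np

def rel_entr_buggy(p, q):
    with np.errstate(divide='ignore', invalid='ignore'):
        main = p * np.log(p / q)
    # bug: the p == 0 edge case falls back to q
    return np.where(p == 0, q, main)

def rel_entr_fixed(p, q):
    with np.errstate(divide='ignore', invalid='ignore'):
        main = p * np.log(p / q)
    # fix: the p == 0 (q >= 0) edge case is defined to be 0,
    # matching scipy.special.rel_entr
    return np.where(p == 0, np.zeros_like(q), main)

p, q = np.array([0.0]), np.array([0.7])
print(rel_entr_buggy(p, q))  # [0.7] -- wrong
print(rel_entr_fixed(p, q))  # [0.]  -- correct
```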