Iteratively reweighted least squares for robust fitting #3170

Open
wants to merge 11 commits into
base: master
from

Conversation

samcoveney
Contributor

@samcoveney samcoveney commented Apr 4, 2024

Adds ability to do iteratively re-weighted least squares for DTI and DKI (with/without positivity constraints) by defining a function to return weights based on the residuals of the last fit.

Implementation is in:

  • iterative_fit_tensor in reconst/dti.py
  • iterative_fit in reconst/dki.py

The methods work by iterating over a standard fit method, such as WLS, NLLS, or CWLS, but between fitting iterations weights are calculated from a supplied function weights_method. Some examples are given in reconst/weights_method.py, but users can define their own functions outside of DIPY and pass the methods as arguments. Outlier rejection is naturally handled by the weights_method supplied.

Example:

# use robust NLLS from dti.py
dkimodel = dki.DiffusionKurtosisModel(gtab, fit_method="RNLLS", return_S0_hat=True, num_iter=10)
dkifit = dkimodel.fit(data, mask=mask)
robust = dkimodel.extra["robust"]  # obtain mask of robust signals

# use robust CWLS from dki.py
from dipy.reconst.weights_method import weights_method_wls_gm as wc
dkimodel = dki.DiffusionKurtosisModel(gtab, fit_method="CWLS", return_S0_hat=True, weights_method=wc, num_iter=10)
dkifit = dkimodel.fit(data, mask=mask)
robust = dkimodel.extra["robust"]  # obtain mask of robust signals
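The iterate-and-reweight scheme described above can be illustrated outside DIPY with a small NumPy sketch. Everything in it is an assumption made for illustration: the helper names (gm_weights, mad_scale, irls_fit), the Geman-McClure style weight formula, the MAD-based scale estimate, and the toy straight-line data are not taken from the PR code, and the sketch fits an ordinary linear model rather than DTI or DKI.

```python
import numpy as np


def gm_weights(residuals, scale):
    """Geman-McClure style weights: large residuals get weights near zero."""
    return scale**2 / (scale**2 + residuals**2) ** 2


def mad_scale(residuals):
    """Robust noise estimate from the median absolute deviation."""
    return 1.4826 * np.median(np.abs(residuals - np.median(residuals)))


def irls_fit(design, y, num_iter=10):
    """Repeat a weighted least-squares fit, reweighting from the residuals."""
    weights = np.ones_like(y)
    for _ in range(num_iter):
        # Weighted least squares: scale each row by sqrt(weight), then solve.
        root_w = np.sqrt(weights)
        params, *_ = np.linalg.lstsq(
            design * root_w[:, None], y * root_w, rcond=None
        )
        # Weights for the next iteration come from the residuals of this fit.
        residuals = y - design @ params
        weights = gm_weights(residuals, mad_scale(residuals))
    return params, weights


# Toy data (hypothetical): a straight line with three gross outliers.
rng = np.random.default_rng(0)
x = np.linspace(0.0, 1.0, 50)
design = np.column_stack([np.ones_like(x), x])
y = 1.0 + 2.0 * x + rng.normal(scale=0.05, size=x.size)
y[[5, 20, 40]] += 5.0

ols_params, *_ = np.linalg.lstsq(design, y, rcond=None)
robust_params, final_weights = irls_fit(design, y)
```

In this sketch the outliers are never explicitly rejected; they simply end up with weights close to zero, which is the sense in which outlier handling falls out of the choice of weight function.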

Notes:

  • It was simpler to implement without a stopping condition on individual voxels
  • I re-implemented RESTORE and NLLS slightly, I think they are better now (we could actually implement RESTORE entirely within this new framework)
  • By fitting to all voxels between iterations (rather than iterating over individual voxels, before moving on to the next voxel) it is possible to define weights based on multiple voxels. I will include my own weight functions for this, implemented in DIPY already, in a future PR

This is a fairly major extension that I think can be applied to other models using the same framework. I am sharing now to benefit everyone, ahead of my own publications, so please be kind!

Also, this has been a lot of work. Sorry I haven't done:

  • tests
  • tutorials / examples
  • full pep8 compliance

but I would like to get feedback at this stage please

@pep8speaks

pep8speaks commented Apr 4, 2024

Hello @samcoveney, Thank you for updating !

Line 90:24: E712 comparison to False should be 'if cond is False:' or 'if not cond:'
Line 185:24: E712 comparison to False should be 'if cond is False:' or 'if not cond:'

Comment last updated at 2024-04-26 15:51:29 UTC

@skoudoro
Member

skoudoro commented Apr 4, 2024

Thank you for this @samcoveney !

@RafaelNH and @arokem, it would be great if you could look at this!

@samcoveney samcoveney changed the title from "Robust split idea" to "Iteratively reweighted least squares for robust fitting" Apr 5, 2024
@RafaelNH
Contributor

RafaelNH commented Apr 5, 2024

Many thanks for this PR. I will take a look asap.

@samcoveney et al - In the meantime, please take a look at PR #3151. It would be great to have that PR merged asap, so that we can avoid some duplicate work here. In particular, as you may have noticed, I increased the dki.py test coverage to 100% in PR #3151; it would be great if we maintained that test coverage in the work here as well.

@samcoveney
Contributor Author

@RafaelNH I will take a look at #3151 asap - your PR should be merged first and I will follow later with rebase if needed (not seeing much clash at the moment, but I haven't written my tests for this PR yet...). Thanks for taking a look here too


codecov bot commented Apr 5, 2024

Codecov Report

Attention: Patch coverage is 93.78698% with 21 lines in your changes are missing coverage. Please review.

Project coverage is 83.92%. Comparing base (16d9153) to head (b3fef2e).
Report is 5 commits behind head on master.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #3170      +/-   ##
==========================================
+ Coverage   83.73%   83.92%   +0.18%     
==========================================
  Files         153      154       +1     
  Lines       21316    21540     +224     
  Branches     3440     3494      +54     
==========================================
+ Hits        17849    18077     +228     
+ Misses       2611     2604       -7     
- Partials      856      859       +3     
Files Coverage Δ
dipy/reconst/multi_voxel.py 95.78% <100.00%> (+1.19%) ⬆️
dipy/reconst/dki.py 96.84% <96.70%> (-0.16%) ⬇️
dipy/reconst/weights_method.py 90.12% <90.12%> (ø)
dipy/reconst/dti.py 95.27% <92.90%> (+2.49%) ⬆️

... and 2 files with indirect coverage changes

@skoudoro
Member

skoudoro commented Apr 9, 2024

When do you think you will have time to look at this @arokem?

Thank you in advance for your update

@arokem
Contributor

arokem commented Apr 9, 2024

Sorry - busy week, so not before next Monday at earliest.

@samcoveney
Contributor Author

I'll try to get tests etc done in the meantime

Contributor

@RafaelNH RafaelNH left a comment


Hi @samcoveney! Thanks for the PR and your hard work here!
I did a first review of it. I haven't yet managed to read all the suggested alterations due to their volume, but I already have some comments that I would like you to address before carrying on. Basically, I want to make sure that our previous DKI core implementations are not affected by the changes for your new fitting method. For instance, you changed the weights of our previous WLS from diag(fn**2) to diag(fn). Could you clarify the rationale behind this adjustment? I'm concerned it might affect the outcomes of previous implementations, unless I'm missing something. Moreover, since what you are proposing is something new, a tutorial explaining the method would help the review process. Also, make sure to add some tests: for example, I quickly tried to run your new implementations on single-voxel simulations and it gives me an error.
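For context on the diag(fn**2) versus diag(fn) question, the sketch below shows a general property of weighted least squares. It is not a statement about what the PR code does. Scaling each row of the design matrix and data by fn and then running an ordinary least-squares solve is equivalent to using W = diag(fn**2) in the normal equations, whereas putting diag(fn) into the normal equations gives a different estimate. The arrays here are random stand-ins, not DKI design matrices or signals.

```python
import numpy as np

rng = np.random.default_rng(1)
design = rng.normal(size=(30, 4))
log_signal = rng.normal(size=30)
fn = np.exp(rng.normal(size=30))  # stand-in for predicted signals (positive)


def solve_normal_equations(design, y, w):
    """Weighted least squares with W = diag(w) in the normal equations."""
    W = np.diag(w)
    return np.linalg.solve(design.T @ W @ design, design.T @ W @ y)


# Convention 1: W = diag(fn**2) in the normal equations.
params_w_squared = solve_normal_equations(design, log_signal, fn**2)

# Convention 2: scale the rows by fn, then solve ordinary least squares.
params_row_scaled, *_ = np.linalg.lstsq(
    design * fn[:, None], log_signal * fn, rcond=None
)

# For comparison: W = diag(fn) in the normal equations is a different problem.
params_w_linear = solve_normal_equations(design, log_signal, fn)
```

Conventions 1 and 2 agree to numerical precision, so whether a change from fn**2 to fn alters results depends on where in the solve the weights are applied.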

@samcoveney
Contributor Author

samcoveney commented Apr 10, 2024

Thanks for the review @RafaelNH - I appreciate getting some heads up about the overall structure, if okay?

You are right that a tutorial and some tests will help. If you can let me know whether you think the overall implementation (in terms of form) is alright, I can get onto adding these tests.

(on the PR, I give some examples of use - does this sort of workflow look okay? I think it was better to allow the user to define the weight functions for the best generality)

Could you give me some information on the error please? In dki.py I recall perhaps an issue with the single voxel fit case, which is based on ndim == 1, perhaps the fact of single voxel is the matter...

@RafaelNH
Contributor

RafaelNH commented Apr 10, 2024

Thanks for the review @RafaelNH - I appreciate getting some heads up about the overall structure, if okay?

You are right that a tutorial and some tests will help. If you can let me know whether you think the overall implementation (in terms of form) is alright, I can get onto adding these tests.

(on the PR, I give some examples of use - does this sort of workflow look okay? I think it was better to allow the user to define the weight functions for the best generality)

Could you give me some information on the error please? In dki.py I recall perhaps an issue with the single voxel fit case, which is based on ndim == 1, perhaps the fact of single voxel is the matter...

Looks good to me so far - I still have to understand the code better, e.g. why the number of iterations has to be at least 4. But the structure looks good, especially if it resembles what we have in dti.py. If you want to be sure (and avoid unnecessary work), you can wait for the feedback from @arokem before implementing the tests and example.

Here is the error that I get when trying to run your implementation in single voxel signals (based on the variables in test_dki.py):

Code
dkimodel = dki.DiffusionKurtosisModel(gtab_2s, fit_method="RNLLS")
dkifit = dkimodel.fit(signal_cross)
robust = dkimodel.extra["robust"]

error:
dipy\dipy\reconst\weights_method.py:165: RuntimeWarning: invalid value encountered in divide
  w = C**2 / (C**2 + residuals**2)**2
dipy\reconst\weights_method.py:165: RuntimeWarning: invalid value encountered in divide
  w = C**2 / (C**2 + residuals**2)**2
dipy\reconst\weights_method.py:165: RuntimeWarning: invalid value encountered in divide
  w = C**2 / (C**2 + residuals**2)**2

Runs fine for:

dkifit = dkimodel.fit(DWI)
robust = dkimodel.extra["robust"]

Thanks again for your hard work on this!

@samcoveney
Contributor Author

samcoveney commented Apr 11, 2024

Thanks @RafaelNH I think this might be related to the test you are doing having no error on the data - can I confirm that this is true? I thought I had handled this, but perhaps not!

Either way, I will add a test for such a case.

edit: if you can send me your full test e.g. DM me on the discord, that would be awesome. Thanks again for helping

edit 2: the num_iter conditions are specified in the example weight functions provided, due to the specific weighting schemes implemented in those functions. There are some comments within those functions to explain it, but I will update the docs.

@samcoveney
Contributor Author

@RafaelNH when you get the chance, can you confirm whether the test you performed was on noiseless data please?

@RafaelNH
Contributor

@RafaelNH when you get the chance, can you confirm whether the test you performed was on noiseless data please?

Yes, I tested it on noiseless synthetic signals.

@samcoveney
Contributor Author

Thanks @RafaelNH I have added tests for DTI, will get them done for DKI and try to address what to do when fitting noiseless signals (issue is that the noise is zero in this case: imagine weighting by inverse standard deviation of noise, when the noise is zero!)
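The noiseless case can be reproduced in a few lines. The sketch below assumes a weight formula of the same shape as the one shown in the warning earlier in the thread, with the scale C estimated from the median absolute deviation of the residuals. That estimator, the function names, and the fallback to uniform weights are illustrative assumptions, not the PR's implementation.

```python
import numpy as np


def mad_scale(residuals):
    """Robust noise estimate; exactly zero when all residuals are equal."""
    return 1.4826 * np.median(np.abs(residuals - np.median(residuals)))


def gm_weights_unguarded(residuals):
    """With zero residuals the scale is zero and the weights become 0/0."""
    C = mad_scale(residuals)
    with np.errstate(invalid="ignore", divide="ignore"):
        return C**2 / (C**2 + residuals**2) ** 2


def gm_weights_guarded(residuals, min_scale=1e-12):
    """One possible handling: fall back to uniform weights if C is ~zero."""
    C = mad_scale(residuals)
    if C < min_scale:
        # No measurable noise, so there is nothing to down-weight.
        return np.ones_like(residuals)
    return C**2 / (C**2 + residuals**2) ** 2


noiseless_residuals = np.zeros(8)
unguarded = gm_weights_unguarded(noiseless_residuals)  # all NaN
guarded = gm_weights_guarded(noiseless_residuals)  # all ones
```

Whether uniform weights, a warning, or an explicit error is the right behaviour for noiseless input is a design choice that this sketch does not settle.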

@samcoveney
Contributor Author

Have updated dki tests as well, should all be done.

@RafaelNH I note that you didn't get an error, just a warning. Is this okay, given that using robust weighting with noiseless data is a user error?

@skoudoro @arokem @Garyfallidis can you also take a look please? Probably I need to add a tutorial, but if that can wait that would be great....

@skoudoro
Member

Yes, I will take a look today.

Can you just make sure to address some failing CI's like:

  • typo CI: there is a typo.
  • ruff CI: issue with import order.
  • Warnings: we switched to a zero-warnings policy on the CI and I can see some warnings popping up during tests. Can you address them?

Thank you !

@samcoveney
Contributor Author

samcoveney commented Apr 26, 2024

Warnings: we switched to zero warnings policy on the CI's and I can see some warnings popping up during tests. Can you address them?

I must have missed this, both the warnings (will have to find those) and the discussion of this change...

@samcoveney
Contributor Author

Note that the code format checks often fail while insisting on things that would fail pep8... shame.

Also:

Line 90:24: E712 comparison to False should be 'if cond is False:' or 'if not cond:'
Line 185:24: E712 comparison to False should be 'if cond is False:' or 'if not cond:'

Please ignore these; they are literally wrong warnings, where pep8 doesn't understand comparisons with NumPy arrays of bools

@skoudoro
Member

Note that the code format checks often fail while insisting on things that would fail pep8... shame.

indeed 🤣 , Thank you for catching that.

@jhlegarreta, can you check the line-length setting for ruff? I suppose the default is wrong and creates an incompatibility between ruff and pep8check. We should also check pep8check.

We should document it; there seems to be a consensus on line-length=88 across different projects. We can update that everywhere.

Please ignore these; they are literally wrong warnings, where pep8 doesn't understand comparisons with NumPy arrays of bools

we can not 😅 , you can do like here:

non_outlier_idx = np.where(np.logical_not(cond))

@jhlegarreta
Contributor

jhlegarreta commented Apr 26, 2024

@jhlegarreta, can you check the line-length setting for ruff? I suppose the default is wrong and creates an incompatibility between ruff and pep8check. We should also check pep8check.

Not sure if I follow this: we are not using ruff to ensure the line length yet. It's pep8speaks that checks this. And pep8speaks is set to 79 if I'm not wrong.

@skoudoro
Member

Not sure if I follow this: we are not using ruff to ensure the line length yet. It's pep8speaks that checks this. And pep8speaks is set to 79 if I'm not wrong.

look here: https://github.com/dipy/dipy/actions/runs/8850919427/job/24306287312?pr=3170#step:4:135

ruff enforces the one-line import from dipy.core.geometry import cart2sphere, perpendicular_directions, sphere2cart, but its length is 82, which is not compatible with pep8check

@jhlegarreta
Contributor

look here: https://github.com/dipy/dipy/actions/runs/8850919427/job/24306287312?pr=3170#step:4:135
ruff enforces the one-line import from dipy.core.geometry import cart2sphere, perpendicular_directions, sphere2cart, but its length is 82, which is not compatible with pep8check

OK, then the solution is to either add a # noqa: E501 to the lines at issue so that pep8speaks does not complain, or tell ruff not to sort the lines at issue (have not looked in depth, but something like # noqa: F401). I'd go with the first option.

@samcoveney
Contributor Author

samcoveney commented Apr 26, 2024

we can not 😅 , you can do like here:

non_outlier_idx = np.where(np.logical_not(cond))

Hmm does this do what is needed? Surely that would need argwhere, for starters?

Anyway, the changes in this PR are quite substantial, if we can focus on that, it'd be great thanks
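As a reference for the E712 and np.where versus np.argwhere exchange above, the following sketch compares the options on a small boolean array. The array names are made up for illustration and do not come from the PR code.

```python
import numpy as np

cond = np.array([True, False, True, False])
values = np.array([10.0, 11.0, 12.0, 13.0])

# Elementwise comparison: correct for arrays, but flagged by E712.
by_comparison = values[cond == False]  # noqa: E712

# Equivalent selections that do not trigger E712.
by_logical_not = values[np.logical_not(cond)]
by_invert = values[~cond]

# np.where with one argument returns a tuple of index arrays, one per
# dimension, which can be used directly for indexing.
idx_where = np.where(np.logical_not(cond))
by_where = values[idx_where]

# np.argwhere returns one row of coordinates per matching element,
# with shape (n_matches, ndim).
idx_argwhere = np.argwhere(np.logical_not(cond))

# The linter's literal suggestion is not elementwise: an identity test on
# an array yields a single Python bool.
identity_check = cond is False
```

All four selections pick the same elements, so np.where works for indexing without argwhere; the two functions differ only in the layout of the indices they return.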

@samcoveney
Contributor Author

samcoveney commented Apr 26, 2024

Looks like cls_fit_dki is randomly failing on noisy data with valid minimum signal. This is true whether or not my new code here is used... I will look into it next week. It could be that kwargs = {"cvxpy_solver": cvxpy.CLARABEL} produces an error, but if I use kwargs = {"cvxpy_solver": cvxpy.ECOS} then I just get a warning. Either way, this suggests that just taking valid data and adding noise often means the constrained solver fails, so something is going wrong here.

I think this may all be related to CWLS when the data is noisy. No tests were created for constrained fitting on noisy data, and since the test data was already 'valid' it is possible that the constrained fitting was doing nothing in the tests. (In other words, this problem may have already existed...)

Possible: the constraints are set up for the log signal, but don't work as planned for the weighted log signal. I'll look into it, any thoughts appreciated.

6 participants