Deep-learning-based registration example #2137

Open · wants to merge 17 commits into master
Conversation

@sarathknv (Contributor)

A simple example of deep-learning-based non-rigid image registration with the MNIST dataset, taking ideas from

Framework

(framework diagram)

An encoder-decoder network is trained to output a dense displacement field that is used to warp the moving image to align with the fixed image. I've used MNIST images of a single class, the same class as the fixed image, for training and testing.
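
As a rough illustration of this warping step (a minimal sketch, not the PR's code; warp_2d and the SciPy-based bilinear resampling are assumptions for illustration):

import numpy as np
from scipy.ndimage import map_coordinates


def warp_2d(moving, disp):
    # Warp an (H, W) image with a dense (H, W, 2) displacement field by
    # sampling the moving image at the displaced grid locations.
    h, w = moving.shape
    grid_y, grid_x = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
    coords = np.stack([grid_y + disp[..., 0], grid_x + disp[..., 1]])
    return map_coordinates(moving, coords, order=1)  # bilinear interpolation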

Sample results

(sample results image)

Colab notebook

https://colab.research.google.com/drive/1fTzz1aT2sb8oAXRO1-dr6O_IR6dof36e

Training takes about 5 minutes on a GPU.

@sarathknv (Contributor, Author)

I've refactored my code a little. @BramshQamar @skoudoro @ShreyasFadnavis @Garyfallidis, your thoughts on this?

Usage:

from dipy.nn.registration import SimpleFCN2d
from dipy.nn.metrics import NormalizedCrossCorrelation, MeanSquaredError


ncc = NormalizedCrossCorrelation()

model = SimpleFCN2d()
model.compile(optimizer='adam', loss=ncc)

# train
model.fit(x_train, y_train, validation_split=0.2, batch_size=8, epochs=10)

# evaluate on test set
model.evaluate(x_test, y_test, batch_size=8)

# save weights
model.save_weights('./weights_fcn2d.h5')

# load weights
model.load_weights('weights_fcn2d.h5')

# for single/batch input
moved = model.predict_on_batch({'moving': moving, 'static': static})

If this API is okay, I'll go ahead and clean it and add the test file.

@codecov (bot) commented May 24, 2020

Codecov Report

Merging #2137 into master will decrease coverage by 2.12%.
The diff coverage is 78.16%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #2137      +/-   ##
==========================================
- Coverage   91.37%   89.25%   -2.13%     
==========================================
  Files         251      259       +8     
  Lines       32514    32864     +350     
  Branches     3423     3444      +21     
==========================================
- Hits        29711    29333     -378     
- Misses       2057     2792     +735     
+ Partials      746      739       -7     
Impacted Files                        Coverage Δ
dipy/nn/registration/fcn.py           24.19% <24.19%> (ø)
dipy/nn/registration/__init__.py      66.66% <66.66%> (ø)
dipy/nn/registration/dataloader.py    80.76% <80.76%> (ø)
dipy/nn/metrics.py                    85.00% <85.00%> (ø)
dipy/nn/tests/test_metrics.py         86.36% <86.36%> (ø)
dipy/nn/tests/test_unet.py            91.80% <91.80%> (ø)
dipy/nn/registration/_utils.py        94.11% <94.11%> (ø)
dipy/nn/registration/unet.py          94.66% <94.66%> (ø)
dipy/viz/app.py                        4.46% <0.00%> (-72.96%) ⬇️
dipy/viz/tests/test_apps.py           24.56% <0.00%> (-70.18%) ⬇️
... and 22 more

@ShreyasFadnavis (Member)

Hi @sarathknv ... I quite like this API design. I'd like some other people's opinions on this!

@skoudoro did you get a chance to see if we can unify the metrics that we already have in DIPY with TF?

@finalelement can you take a look too? Does this align with your PR? Anything else you'd like the API to have, or anything we missed?

@skoudoro (Member)

Hi @sarathknv,

Thank you for this update!

  1. Personally, I prefer the compile method to be in the __init__ of your class, as in the MultipleLayersPerceptron class. I do not think the user needs this extra work. However, I like that your loss parameter is a callable; I think we should keep that.

  2. Concerning load_weights/save_weights, I think they are good, and you should keep these methods.

  3. I am not sure what the difference is between predict and predict_on_batch. Can you tell me more? Could it just be an option of predict?

  4. Concerning the metrics: do you think you will add more methods to these classes? If not, we do not need classes, and you can keep them as functions with a decorator. A class with only __init__ and __call__ methods is a sign that it can be simpler.

  5. There is a lot of code... too much code... it will be hard to review. Please consider 1 PR = 1 model; avoid adding 3-4 models in the same PR. Let's focus on one model and make sure that we all agree on the API.

Thank you

@sarathknv (Contributor, Author)

Thanks for the feedback @ShreyasFadnavis and @skoudoro ☺️

  1. Personally, I prefer the compile method to be in the __init__ of your class, as in the MultipleLayersPerceptron class. I do not think the user needs this extra work. However, I like that your loss parameter is a callable; I think we should keep that.

I agree with keeping the API simple, but I also think it's good to have the compile method accessible to the user. If we want to change the parameters of the loss or optimizer, like the learning rate, or maybe swap the optimizer itself, we can just call compile to update them without affecting the model weights. I generally try a few learning rates on sample data before running the full-scale experiment, to get a feel for how long training will take. So I thought it would be nice to have it as a method. ☺️

We could do something like this:

class CNN(object):
    def __init__(self, loss=None, optimizer='adam'):
        # self.model is the underlying Keras model, built beforehand.
        self.model.compile(loss=loss, optimizer=optimizer)

    def compile(self, optimizer='adam', loss=None):
        self.model.compile(optimizer=optimizer, loss=loss)

  2. Concerning load_weights/save_weights, I think they are good, and you should keep these methods.

👍

  3. I am not sure what the difference is between predict and predict_on_batch. Can you tell me more? Could it just be an option of predict?

predict is for inputs with many samples; it internally processes the data batch by batch. For inputs with only a few samples, predict has overhead and is slower than predict_on_batch. This StackOverflow thread clears up the differences and also has some benchmarks.

From the TensorFlow docs:

Computation is done in batches. This method is designed for performance in large scale inputs. For small amount of inputs that fit in one batch, directly using call is recommended for faster execution, e.g., model(x), or model(x, training=False)
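
For small inputs the trade-off looks like this (a sketch with a generic Keras model; x_small is a hypothetical small batch):

out_fast = model.predict_on_batch(x_small)      # one forward pass, no batching machinery
out_std = model.predict(x_small, batch_size=8)  # full batched pipeline, extra overhead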

If you think predict and predict_on_batch would be confusing to the user, we can have a __call__ method that calls predict_on_batch. Then inference usage would be simple:

model = CNN()
moved = model({'moving': moving, 'static': static})
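
The wrapper itself would be a one-liner (a sketch; self.model is the underlying Keras model as above):

class CNN(object):
    ...

    def __call__(self, inputs):
        # Thin wrapper: model(x) runs a single fast forward pass.
        return self.model.predict_on_batch(inputs)
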
  4. Concerning the metrics: do you think you will add more methods to these classes? If not, we do not need classes, and you can keep them as functions with a decorator. A class with only __init__ and __call__ methods is a sign that it can be simpler.

Changed the metrics from classes to functions with decorators. 👍
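
For reference, the function form could look roughly like this (a sketch, not the PR's implementation; the name and sign convention are assumptions):

import tensorflow as tf


def normalized_cross_correlation(y_true, y_pred, eps=1e-8):
    # Flatten each sample, zero-center, and correlate.
    y_true = tf.reshape(y_true, (tf.shape(y_true)[0], -1))
    y_pred = tf.reshape(y_pred, (tf.shape(y_pred)[0], -1))
    y_true = y_true - tf.reduce_mean(y_true, axis=1, keepdims=True)
    y_pred = y_pred - tf.reduce_mean(y_pred, axis=1, keepdims=True)
    num = tf.reduce_sum(y_true * y_pred, axis=1)
    den = tf.norm(y_true, axis=1) * tf.norm(y_pred, axis=1) + eps
    # Negate so that minimizing the loss maximizes correlation.
    return -tf.reduce_mean(num / den)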

  5. There is a lot of code... too much code... it will be hard to review. Please consider 1 PR = 1 model; avoid adding 3-4 models in the same PR. Let's focus on one model and make sure that we all agree on the API.

Removed VoxelMorph models. 👍

@skoudoro (Member)

We could do something like this:

class CNN(object):
    def __init__(self, loss=None, optimizer='adam'):
        # self.model is the underlying Keras model, built beforehand.
        self.model.compile(loss=loss, optimizer=optimizer)

    def compile(self, optimizer='adam', loss=None):
        self.model.compile(optimizer=optimizer, loss=loss)

Ok, a great hybrid solution, I am fine with that 👍.

predict is for inputs with many samples; it internally processes the data batch by batch. For inputs with only a few samples, predict has overhead and is slower than predict_on_batch. This StackOverflow thread clears up the differences and also has some benchmarks.

Ok, thanks for the reference, I will read more about it.

If you think predict and predict_on_batch would be confusing to the user, we can have a __call__ method that calls predict_on_batch. Then inference usage would be simple:

Yes, I think they are confusing. It needs more thinking, and we need to be explicit, so I would not recommend using __call__ here.

Removed VoxelMorph models. 👍

Thank you for the cleaning!

raise ImportError('Please upgrade to TensorFlow 2+')


def regular_grid_2d(height, width):
Review comment (Contributor):

Missing docstring

@BramshQamar (Contributor) left a comment

Nice work, @sarathknv!

This PR needs documentation.

dipy/nn/registration/_utils.py (resolved)
__all__ = ["SimpleFCN2d"]


class SimpleFCN2d(object):
Review comment (Contributor):

All functions in this class are missing docstrings.

@BramshQamar (Contributor)

  I've refactored my code a little. [...] If this API is okay, I'll go ahead and clean it and add the test file.

It would be useful to have documentation or a tutorial on how to use the API, similar to the usage example you have provided here.

Here are the existing tutorials in DIPY.

The tutorial can be added later in a different PR.

@finalelement

@ShreyasFadnavis Thanks for the tag. Will check out the code.

@pep8speaks (bot) commented Jun 26, 2020

Hello @sarathknv, thank you for updating!

Line 138:42: W292 no newline at end of file

Line 60:76: W291 trailing whitespace
Line 218:76: W291 trailing whitespace
Line 226:79: W291 trailing whitespace

Comment last updated at 2020-08-21 16:32:45 UTC

@arokem (Contributor) left a comment

Nice work. Here are a few comments/suggestions.

class SimpleFCN2d(object):
    def __init__(self, input_shape=(32, 32)):
        in_channels = 1

class FCN2d(object):
Review comment (Contributor):

What does "FCN" stand for? Could we spell this out in the class name?

@sarathknv (Author) replied:

Really sorry for the late reply. I can be active from now on.

This was some custom architecture that I wrote for the POC. I’ve added a standard model (U-Net) and will be removing fcn.py.

dipy/nn/registration/unet.py

dipy/nn/registration/fcn.py (resolved)
"""


class DataLoader(tf.keras.utils.Sequence):
Review comment (Contributor):

Is this something that would be worth implementing as library code?

@sarathknv (Author) replied:

It's definitely useful, but we need to handle several cases if we plan to have this in the library.
Current location of the class: dipy/nn/registration/dataloader.py

The dataloader's job is to feed batches of processed data to the model. The user passes the entire dataset either as arrays (shape (num_images, H, W, …, num_channels)) or as a list of file paths, and the dataloader performs some preprocessing (data transformations, adding dimensions, changing data types) on each batch before feeding it to the model. Currently this code only supports NumPy arrays as input, which works for small datasets that fit in memory, but the general use case would be file paths as input.

Some cases that we need to handle (a sketch of the array-input case follows this list):

  1. NumPy arrays as input
    1. Single static image and multiple moving images - All the moving images will be registered to this static image (moving shape (N, H, W, ..., C), static shape (H, W, ..., C))
    2. Multiple static and moving images - every moving image has a different static image (moving shape (N, H, W, C), static shape (N, H, W, C))
  2. Lists of file paths as input
    1. Single static image and multiple moving images
    2. Multiple static and moving images
    3. Different file extensions .npy, .h5, .nii, …
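
A minimal sketch of the array-input case (hypothetical class name and conventions; one static image is broadcast over all moving images when needed):

import numpy as np
import tensorflow as tf


class PairedArraySequence(tf.keras.utils.Sequence):
    def __init__(self, moving, static, batch_size=8):
        # moving: (N, H, W, ..., C); static: same shape, or a single
        # (H, W, ..., C) image to be reused for every moving image.
        if static.ndim == moving.ndim - 1:
            static = np.broadcast_to(static, moving.shape)
        self.moving = moving.astype('float32')
        self.static = static.astype('float32')
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.moving) / self.batch_size))

    def __getitem__(self, idx):
        sl = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        inputs = {'moving': self.moving[sl], 'static': self.static[sl]}
        # The static image is also the target the moved image should match.
        return inputs, self.static[sl]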

the moving images.
"""

moved = moved.squeeze(axis=-1) # Remove the channel dim.
Review comment (Contributor):

Would this little block of code be useful to implement as module code?

@sarathknv (Author) replied:

This is only needed for single-channel images. Since most images are single-channel, I also think it would be useful to implement as library code. Do you mean something like this? Add a squeeze option to the predict method:

moved = model.predict(data, squeeze=True)
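
Something along these lines (a sketch of the proposed option, inside the model class):

def predict(self, data, squeeze=False):
    moved = self.model.predict(data)
    if squeeze:
        # Drop the trailing channel axis for single-channel images.
        moved = np.squeeze(moved, axis=-1)
    return moved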

ax.imshow(images_list[j][i], cmap='gray')
plt.tight_layout()
plt.show()
plt.savefig('sample_results.png')
Review comment (Contributor):

This figure is empty for me when I run this.

Review comment (Contributor):

Maybe use fig.savefig instead?
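
For example (a sketch of the suggested fix; saving from the figure handle before plt.show() also avoids writing an empty file, since show() can close the current figure):

import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows, ncols)  # nrows/ncols as in the example script
# ... plot into axes ...
plt.tight_layout()
fig.savefig('sample_results.png')  # save before show(), via the figure handle
plt.show()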

"""

batch_size = 32
epochs = 100
Review comment (Contributor):

It looks like performance converges after far fewer epochs. Maybe set this to 25 for the purposes of the example, and add a comment that it could be increased to a larger number for more accurate results?
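
That is, something like (a sketch of the suggested change):

batch_size = 32
epochs = 25  # increase (e.g., back to 100) for more accurate results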

@Garyfallidis (Contributor)

@sarathknv please enable automatic trailing whitespace removal in your IDE. It will help you resolve pep8 issues faster.

@Garyfallidis (Contributor)

@sarathknv, @BramshQamar, @ShreyasFadnavis, @skoudoro what is the progress here? Does this need a rebase?

@Paluck1Arora2

Can anyone please provide the dataset for this code: https://github.com/sarathknv/voxelmorph

data = np.load('drive/My Drive/DIPY/brain_affine_128.npy')
static = np.load('drive/My Drive/DIPY/static.npy')

These lines are in the voxelmorph.ipynb file.

data = np.load('data/t1_moving_128.npy')
static = np.load('data/t1_static_128.npy')

These lines are in main.py.

@jhlegarreta (Contributor) commented Jan 9, 2023

#2137 (comment) @Paluck1Arora2 Please post questions in DIPY's Gitter room, mailing list, or the DIPY wiki:
https://dipy.org/support/

Also, please avoid asking the same question as a comment and opening an issue at the same time:
#2708

Note that issues should primarily be opened to either raise attention about bugs or to propose/ask about new features.

Thanks.

@skoudoro skoudoro force-pushed the master branch 7 times, most recently from 1419292 to ca6268a Compare December 8, 2023 22:25
@skoudoro skoudoro force-pushed the master branch 3 times, most recently from 5935e1e to 765963e Compare January 24, 2024 19:24