-
Notifications
You must be signed in to change notification settings - Fork 429
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Deep-learning-based registration example #2137
base: master
Are you sure you want to change the base?
Conversation
I've refactored my code a little. @BramshQamar @skoudoro @ShreyasFadnavis @Garyfallidis your thoughts on this. Usage:
If this API is okay, I'll go ahead and clean it and add the test file. |
Codecov Report
@@ Coverage Diff @@
## master #2137 +/- ##
==========================================
- Coverage 91.37% 89.25% -2.13%
==========================================
Files 251 259 +8
Lines 32514 32864 +350
Branches 3423 3444 +21
==========================================
- Hits 29711 29333 -378
- Misses 2057 2792 +735
+ Partials 746 739 -7
|
Hi @sarathknv ... I quite like this API design. Would like some other people's opinions on this! @skoudoro did you get a chance to see if we can unify the metrics that we already have in DIPY with TF? @finalelement can you take a look too? Does this abide by your PR? Any other stuff you'd like the API to have/ we missed out on? |
Hi @sarathknv, Thank you for this update!
Thank you |
Thanks for the feedback @ShreyasFadnavis and @skoudoro
I agree with the idea of keeping the API simple. But I also think that it's good to have the We could do something like this
👍
From the TensorFlow docs
If you think
Changed the metrics from classes to functions with decorators. 👍
Removed VoxelMorph models. 👍 |
class CNN(object):
def __init__(self, loss=None, optimizer='adam')
self.model.compile(loss=loss, optimizer=optimizer)
def compile(self, optimizer='adam', loss=None):
self.model.compile(optimizer=optimizer, loss=loss) ok, a great hybrid solution, I am fine with that 👍.
Ok, thanks for the reference, I will read more about it.
Yes, I think they are confusing. It needs more thinking and we need to be explicit so I would not recommend using
Thank you for the cleaning! |
raise ImportError('Please upgrade to TensorFlow 2+') | ||
|
||
|
||
def regular_grid_2d(height, width): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing docstring
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work @sarathknv !
This PR needs documentation.
dipy/nn/registration/fcn.py
Outdated
__all__ = ["SimpleFCN2d"] | ||
|
||
|
||
class SimpleFCN2d(object): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All functions in this class missing docstring
It would be useful to have documentation/tutorial on how to use the API. Something similar to the usage example you have provided here. Here are the existing tutorials in DIPY. Tutorial can be added later in a different PR. |
@ShreyasFadnavis Thanks for the tag. Will check out the code. |
Hello @sarathknv, Thank you for updating !
Comment last updated at 2020-08-21 16:32:45 UTC |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work. Here are a few comments/suggestions.
class SimpleFCN2d(object): | ||
def __init__(self, input_shape=(32, 32)): | ||
in_channels = 1 | ||
class FCN2d(object): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does "FCN" stand for? Could we spell this out in the class name?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really sorry for the late reply. I can be active from now on.
This was some custom architecture that I wrote for the POC. I’ve added a standard model (U-Net) and will be removing fcn.py.
""" | ||
|
||
|
||
class DataLoader(tf.keras.utils.Sequence): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this something that would be worth implementing as library code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It’s definitely useful, but we need to handle several cases if we are planning to have this in the library.
Current location of the class dipy/nn/registration/dataloader.py
The function of the dataloader object is to feed batches of processed data to the model. The user needs to pass the entire dataset as input either as arrays (shape (num_images, H, W, …, num_channels)) or as a list of file paths, and the dataloader object performs some preprocessing (data transformations, adding dimensions, changing data types) on batches of data and feeds it to the model. Currently, this code only supports NumPy arrays as input, which will work for small datasets that can be loaded into the memory, but a general use case would be file paths as input.
Some cases that we need to handle:
- NumPy arrays as input
- Single static image and multiple moving images - All the moving images will be registered to this static image (moving shape (N, H, W, ..., C), static shape (H, W, ..., C))
- Multiple static and moving images - every moving image has a different static image (moving shape (N, H, W, C), static shape (N, H, W, C))
- Lists of file paths as input
- Single static image and multiple moving images
- Multiple static and moving images
- Different file extensions .npy, .h5, .nii, …
the moving images. | ||
""" | ||
|
||
moved = moved.squeeze(axis=-1) # Remove the channel dim. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would this little block of code be useful to implement as module code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is only needed for single-channel images. Since most images are single-channel I also think it is useful to implement as library code. Do you mean something like this? Add the squeeze option in the predict method
moved = model.predict(data, squeeze=True)
ax.imshow(images_list[j][i], cmap='gray') | ||
plt.tight_layout() | ||
plt.show() | ||
plt.savefig('sample_results.png') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This figure is empty for me when I run this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe use fig.savefig
instead?
""" | ||
|
||
batch_size = 32 | ||
epochs = 100 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like performance converges after much fewer epochs. Maybe set this to 25 for the purpose of the example and then add a comment that this could be increased to a larger number for more accurate results?
@sarathknv please enable automatic trailing whitespace removal in your IDE. It will help you resolve pep8 issues faster. |
@sarathknv, @BramshQamar, @ShreyasFadnavis, @skoudoro what is the progress here? Does this need a rebase? |
Can anyone please provide the dataset of this code "https://github.com/sarathknv/voxelmorph" data = np.load('data/t1_moving_128.npy') |
#2137 (comment) @Paluck1Arora2 Please post questions in DIPY's gitter room, mailing list or the DIPY wiki: Also, please avoid asking the same question as a comment and opening an issue at the same time: Note that issues should primarily be opened to either raise attention about bugs or to propose/ask about new features. Thanks. |
1419292
to
ca6268a
Compare
5935e1e
to
765963e
Compare
A simple example of deep-learning-based non-rigid image registration with the MNIST dataset, taking ideas from
Framework
An encoder-decoder network is trained to output a dense displacement field that is used to warp the moving image to align with the fixed image. I've used MNIST images of one class, same as that of the fixed image, for training and testing.
Sample results
Colab notebook
https://colab.research.google.com/drive/1fTzz1aT2sb8oAXRO1-dr6O_IR6dof36e
Training takes about 5 minutes on GPU.