[WIP] Hugging Face Integration #531
base: master
Conversation
@dataclass
class Weights:
Can you put a reference to torchvision in the docstring of this class to say we mimic their behavior?
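For illustration, a docstring along these lines might do it (a minimal sketch; the field names are placeholders mimicking torchvision's `Weights`/`WeightsEnum` design, not necessarily the PR's actual fields):

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class Weights:
    """Container for one set of pretrained weights plus metadata.

    This mimics the behavior of ``torchvision.models.Weights``
    (https://pytorch.org/vision/stable/models.html), which pairs each
    weight file with metadata describing how it was produced.
    """

    path: str  # location of the weight file
    meta: Dict[str, Any] = field(default_factory=dict)  # e.g. dataset, accuracy
```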
path: str

class WeightsEnum(Enum):
Same here
braindecode/pretrained.py
Outdated
Initialize and return a model specified by the `name` parameter. If `dataset_name` and `subject_id` are provided, pretrained weights associated with those parameters will be downloaded and used for initialization; otherwise, random initialization will be performed.
Not sure I like this. This forces models to be trained on one subject only. Loading pre-trained weights will be especially useful for general models that were trained on large/multiple datasets. And even on the same dataset, we can provide multiple pre-trained weights that were obtained with different training algorithms.
I think the weights name should not have constraints.
We could say that passing a dataset name str (or a list of strs for multiple datasets) with subject_id=None implies you want to download a model that has been trained on all subjects of those given datasets?
For different training algorithms we could add another variable. It does then start to get slightly messy, however, and I guess the alternative is we force the user to type the pretrained model name explicitly, e.g.:
model = initialize_model(name='ShallowFBCSPNet', weights=ShallowFBCSPNet['SomeDataset_S1_V2'])
I think you should just replace the parameters subject_id and dataset by weights_id, and provide naming guidelines for weights_id with a few examples.
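To make that concrete, here is one possible weights_id convention encoding dataset, subject scope, and version in a single string (purely hypothetical naming, not something defined in this PR; the dataset names are just examples):

```python
# Hypothetical weights_id naming convention: <dataset>_<scope>_<version>
examples = [
    "BNCI2014001_S1_V1",     # trained on subject 1 of one dataset
    "BNCI2014001_all_V1",    # trained on all subjects of one dataset
    "multi-dataset_all_V2",  # trained across several datasets
]


def parse_weights_id(weights_id: str):
    """Split a weights_id into its (dataset, scope, version) parts."""
    dataset, scope, version = weights_id.rsplit("_", 2)
    return dataset, scope, version


for wid in examples:
    print(parse_weights_id(wid))
```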
@bruAristimunha what do you think?
I agree - best would be to define a convention but not restrict people to defining models by subject_id and dataset.
I agree too
braindecode/pretrained.py
Outdated
MODELS_AND_WEIGHTS = {
    "shallowfbcspnet": {"model": ShallowFBCSPNet, "weights": ShallowFBCSPNetWeights}
    # Other models go here
}
maybe we can follow the convention described in #524 and put that dict in models.utils
Also, why do you put the model name in lower case?
Oh yeah, lower case to make it easier to type the correct model name (the name gets .lower() in initialize_model), but we can probably trust the user to do this correctly.
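The lookup being discussed amounts to something like this (a sketch; the placeholder class stands in for the real braindecode model, and the registry here carries only the pieces needed to show the case-insensitive lookup):

```python
class ShallowFBCSPNet:
    """Placeholder for braindecode's real ShallowFBCSPNet model class."""


# Keys are lower-cased so that user-supplied names match regardless of case.
MODELS_AND_WEIGHTS = {
    "shallowfbcspnet": {"model": ShallowFBCSPNet},
}


def initialize_model(name: str):
    """Look up a model case-insensitively, as the PR currently does."""
    entry = MODELS_AND_WEIGHTS.get(name.lower())
    if entry is None:
        raise ValueError(f"Unknown model name: {name!r}")
    return entry["model"]()


model = initialize_model("ShallowFBCSPNet")  # works despite mixed case
```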
model = initialize_model('ShallowFBCSPNet', dataset_name=dataset_name, subject_id=subject_id)

clf = EEGClassifier(
    model,
    device=device,
)
We should find a way to be able to do directly something like this:
-model = initialize_model('ShallowFBCSPNet', dataset_name=dataset_name, subject_id=subject_id)
-clf = EEGClassifier(
-    model,
-    device=device,
-)
+clf = EEGClassifier.from_pretrained(
+    ShallowFBCSPNet,
+    device=device,
+    weights_name='...',
+)
+# downloads the model args,
+# instantiates the EEGClassifier with these args,
+# initializes the EEGClassifier (including the model),
+# and loads the pre-trained weights in the model.
Yeah, we can also check what API other libraries use for this, e.g. does skorch have some way already? Also, would we even need a separate from_pretrained method? Torchvision, for example, has weights as another parameter in the constructor: https://pytorch.org/vision/stable/models.html
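Sketching the from_pretrained idea under discussion (everything here is hypothetical: skorch's EEGClassifier has no such classmethod today, the fetch_* helpers stand in for whatever hub download braindecode would use, and DummyNet is a stand-in model):

```python
class DummyNet:
    """Stand-in for a braindecode model such as ShallowFBCSPNet."""

    def __init__(self, n_chans):
        self.n_chans = n_chans
        self.state = {}

    def load_state_dict(self, state_dict):
        self.state.update(state_dict)


def fetch_model_args(weights_name):
    # Hypothetical hub download of the stored constructor arguments;
    # stubbed with fixed values here.
    return {"n_chans": 22}


def fetch_state_dict(weights_name):
    # Hypothetical hub download of the trained parameters; stubbed empty.
    return {}


class EEGClassifier:
    """Minimal stand-in for braindecode's skorch-based classifier."""

    def __init__(self, module, device="cpu"):
        self.module = module
        self.device = device

    @classmethod
    def from_pretrained(cls, model_cls, weights_name, device="cpu"):
        # Download the stored model args, instantiate the model with them,
        # load the pretrained weights, and wrap it all in the classifier.
        model = model_cls(**fetch_model_args(weights_name))
        model.load_state_dict(fetch_state_dict(weights_name))
        return cls(model, device=device)


clf = EEGClassifier.from_pretrained(DummyNet, weights_name="example_all_V1")
```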
No clue if it should be done in this PR, but if we plan to share models and load them easily into Python, maybe we should also think about a better format than
@sliwy I agree but probably for another PR
Hey @dcwil, do you need some help with the new extra requirement?
# Conflicts: # braindecode/models/shallow_fbcsp.py
Codecov Report
@@ Coverage Diff @@
## master #531 +/- ##
==========================================
- Coverage 84.72% 83.45% -1.28%
==========================================
Files 63 65 +2
Lines 4741 4878 +137
==========================================
+ Hits 4017 4071 +54
- Misses 724 807 +83
Hi @dcwil, in the CI it seems that the generation of the documentation is failing due to something related to the new functionality.
I haven't had the chance to look into it very deeply yet, but it could be that the weights I uploaded in Paris are now out of date with the code somehow?
Do you have any time to look into this @dcwil?
Not right now I'm afraid - it's on my to-do list but has been superseded by a couple of other things.
#480
What's been done:
What would likely need to be added:
- initialize_model variant or toggle for initialising a skorch classifier?

To move forwards, need to know if: