
[WIP] Hugging Face Integration #531

Open · wants to merge 15 commits into master
Conversation

dcwil (Collaborator) commented Sep 11, 2023

#480
What's been done:

  • Set up a basic Hugging Face workflow (loosely following the torchvision style)
  • Uploaded some pretrained nets to my personal repo (ShallowFBCSPNet on BNCI2014001, same params except 200 epochs)
  • Extended the trialwise decoding example to demonstrate the pretrained network (see the sketch at the end of this comment)

What would likely need to be added:

  • More models/datasets (see below)
  • A specific example for using pretrained networks
  • Create (and use) a braindecode HF repo
  • Allow users to upload networks trained with braindecode to their own HF repo?
  • An initialize_model variant or toggle for initialising a skorch classifier?

To move forward, we need to know:

  • whether we're happy with how the weights will be loaded (from a user perspective, and also internally)
  • exactly which models should be run, on which datasets, and with which parameters, etc. (do we need a hyperparameter exploration first?)
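
For reference, a minimal sketch of the usage this enables. The import path is an assumption, and the exact signature is still under discussion below:

# Hypothetical import path for the helper added in this PR
from braindecode.models.pretrained import initialize_model

# With dataset_name and subject_id given, pretrained weights are downloaded
# from the Hugging Face Hub; without them, the model is randomly initialized.
model = initialize_model('ShallowFBCSPNet', dataset_name='BNCI2014001', subject_id=1)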






@dataclass
class Weights:
Collaborator: Can you put a reference to torchvision in the docstring of this class to say we mimic their behavior?

    path: str


class WeightsEnum(Enum):
Collaborator: Same here.
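
A sketch of what the requested docstrings might look like (wording is illustrative only):

from dataclasses import dataclass
from enum import Enum


@dataclass
class Weights:
    """Metadata for a set of pretrained weights.

    Mimics the behavior of torchvision's ``Weights`` dataclass; see
    https://pytorch.org/vision/stable/models.html.
    """
    path: str


class WeightsEnum(Enum):
    """Enumeration of available pretrained weights, mimicking
    ``torchvision.models.WeightsEnum``."""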

Comment on lines 16 to 18
Initialize and return a model specified by the `name` parameter. If `dataset_name` and
`subject_id` are provided, pretrained weights associated with those parameters will be downloaded
and used for initialization; otherwise, random initialization will be performed.
Collaborator: Not sure I like this. It forces models to be trained on one subject only. Loading pre-trained weights will be especially useful for general models trained on large or multiple datasets. And even on the same dataset, we could provide multiple pre-trained weights obtained with different training algorithms.

I think the weights name should not have constraints.

dcwil (Author): We could say that passing a dataset name str (or a list of strs for multiple datasets) with subject_id=None implies you want to download a model trained on all subjects of those datasets?

For different training algorithms we could add another variable. It does start to get slightly messy then, however, and I guess the alternative is to force the user to type the pretrained weights name explicitly, e.g.:
model = initialize_model(name='ShallowFBCSPNet', weights=ShallowFBCSPNetWeights['SomeDataset_S1_V2'])

Collaborator: I think you should just replace the subject_id and dataset_name parameters by weights_id, and provide naming guidelines for weights_id with a few examples.
@bruAristimunha what do you think?
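
A hypothetical illustration of that proposal - the weights_id values and the naming convention are invented for the example:

# One possible "<dataset(s)>_<scope>_<version>" convention for weights_id:
#   'BNCI2014001_subject1_v1'               -> trained on a single subject
#   'BNCI2014001_all_v1'                    -> trained on all subjects of one dataset
#   'BNCI2014001+Schirrmeister2017_all_v1'  -> trained across multiple datasets
model = initialize_model('ShallowFBCSPNet', weights_id='BNCI2014001_all_v1')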

Collaborator: I agree - best would be to define a convention but not restrict people to defining models by subject_id and dataset.

Collaborator: I agree too.

Comment on lines 6 to 9
MODELS_AND_WEIGHTS = {
"shallowfbcspnet": {"model": ShallowFBCSPNet, "weights": ShallowFBCSPNetWeights}
# Other models go here
}
Collaborator: Maybe we can follow the convention described in #524 and put that dict in models.utils.

Also, why do you put the model name in lower case?

dcwil (Author): Oh yeah - lower case to make it easier to type the correct model name (the name gets .lower()-ed in initialize_model), but we can probably trust the user to do this correctly.
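
For illustration, a minimal sketch of the lookup behaviour described above (the helper name is hypothetical):

def _get_model_and_weights(name: str):
    # Case-insensitive lookup: keys are stored lower-cased, so
    # 'ShallowFBCSPNet' and 'shallowfbcspnet' resolve to the same entry.
    try:
        return MODELS_AND_WEIGHTS[name.lower()]
    except KeyError:
        raise ValueError(f"Unknown model name: {name!r}") from None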

Comment on lines 347 to 352
model = initialize_model('ShallowFBCSPNet', dataset_name=dataset_name, subject_id=subject_id)

clf = EEGClassifier(
model,
device=device,
)
Collaborator: We should find a way to directly do something like this:

Suggested change:

- model = initialize_model('ShallowFBCSPNet', dataset_name=dataset_name, subject_id=subject_id)
- clf = EEGClassifier(
-     model,
-     device=device,
- )
+ clf = EEGClassifier.from_pretrained(
+     ShallowFBCSPNet,
+     device=device,
+     weights_name='...',
+ )
+ # downloads the model args,
+ # instantiates the EEGClassifier with these args,
+ # initializes the EEGClassifier (including the model),
+ # and loads the pre-trained weights into the model.

Contributor: Yeah, we can also check what API other libraries use for this - e.g. does skorch already have some way? And would we even need a separate from_pretrained method? Torchvision, for example, takes weights as another parameter in the constructor: https://pytorch.org/vision/stable/models.html
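
For comparison, torchvision's constructor-level API looks like this; the braindecode analogue below it is hypothetical:

from torchvision.models import resnet50, ResNet50_Weights

# torchvision: weights are just another constructor argument
model = resnet50(weights=ResNet50_Weights.DEFAULT)

# Hypothetical braindecode analogue (not part of this PR):
# model = ShallowFBCSPNet(n_chans, n_classes, weights=ShallowFBCSPNetWeights['...'])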

sliwy (Collaborator) commented Sep 11, 2023

No clue if it should be done in this PR, but if we plan to share models and load them easily into Python, maybe we should also think about a safer format than .pkl - for example safetensors (https://github.com/huggingface/safetensors#yet-another-format-)?

PierreGtch (Collaborator): @sliwy I agree, but probably for another PR.
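
For reference, a minimal sketch of safetensors-based saving and loading (the model here is a stand-in):

import torch
from safetensors.torch import save_file, load_file

model = torch.nn.Linear(4, 2)  # stand-in for a braindecode model

# Save only the state dict - safetensors cannot execute arbitrary code on
# load, unlike pickle-based .pkl files.
save_file(model.state_dict(), "model.safetensors")

# Load back into a freshly constructed model of the same architecture.
model.load_state_dict(load_file("model.safetensors"))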

bruAristimunha (Collaborator): Hey @dcwil, do you need some help with the new extra requirement?

codecov bot commented Sep 27, 2023

Codecov Report

Merging #531 (c85b47f) into master (9cdde83) will decrease coverage by 1.28%.
Report is 4 commits behind head on master.
The diff coverage is 62.90%.

@@            Coverage Diff             @@
##           master     #531      +/-   ##
==========================================
- Coverage   84.72%   83.45%   -1.28%     
==========================================
  Files          63       65       +2     
  Lines        4741     4878     +137     
==========================================
+ Hits         4017     4071      +54     
- Misses        724      807      +83     

bruAristimunha (Collaborator): Hi @dcwil,

In the CI, it seems that the documentation build is failing due to something related to the new functionality.

dcwil (Author) commented Sep 29, 2023

> Hi @dcwil, in the CI, it seems that the documentation build is failing due to something related to the new functionality.

I haven't had the chance to look into it very deeply yet, but it could be that the weights I uploaded in Paris are now out of date with the code somehow?

robintibor (Contributor): Do you have any time to look into this @dcwil?

robintibor mentioned this pull request Nov 6, 2023
dcwil (Author) commented Nov 6, 2023

> Do you have any time to look into this @dcwil?

Not right now, I'm afraid - it's on my to-do list but has been superseded by a couple of other things.
