
create_lr_scheduler_with_warmup does not change init_lr to proper value #2441

Open · sadra-barikbin opened this issue Jan 24, 2022 · 10 comments
@sadra-barikbin (Collaborator)

Hi,

To get the expected sequence of LRs from create_lr_scheduler_with_warmup's scheduler, one must not attach it to the engine on the EPOCH_COMPLETED event, because it first produces the LR that was passed to the optimizer's constructor and only then the warmup LRs. This breaks the warmup procedure. As a workaround, one could attach it on EPOCH_STARTED instead (sketched after the reproduction below), but that might not be a good solution.

It seems there should be something like the following at line 1017 of param_scheduler.py, within the for loop:

param_group['lr'] = warmup_start_value
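
For context, a rough sketch of where this assignment would go (hypothetical; the loop over the optimizer's param groups is assumed, and the actual code around line 1017 may differ):

# Hypothetical sketch, not the actual library code: initialize every
# param group's LR to the warmup start value up front, so the first
# scheduler event already sees a warmup LR instead of the constructor's LR.
for param_group in optimizer.param_groups:
    param_group['lr'] = warmup_start_value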

To reproduce the current behaviour:

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

from ignite.engine import Engine, Events
from ignite.handlers import create_lr_scheduler_with_warmup

param = nn.Parameter(torch.tensor([3.0]))
optimizer = torch.optim.SGD([param], lr=1e-3)
scheduler = StepLR(optimizer, 3)
with_warmup_scheduler = create_lr_scheduler_with_warmup(scheduler, warmup_start_value=1e-5, warmup_duration=3)

def process_func(e, b):
    param.grad = torch.tensor([1.0])
    optimizer.step()

trainer = Engine(process_func)

@trainer.on(Events.EPOCH_COMPLETED)
def print_lr():
    print(optimizer.param_groups[0]["lr"])

# Attached after the print handler, so each epoch prints the LR *before* the scheduler updates it
trainer.add_event_handler(Events.EPOCH_COMPLETED, with_warmup_scheduler)
trainer.run([0], max_epochs=10)

Output:

0.001
1e-05
0.000505
0.001
0.001
0.001
0.0001
0.0001
0.0001
1e-05
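
For reference, the workaround mentioned above would attach the same handler on EPOCH_STARTED instead, so the warmup value is applied before each epoch runs (a sketch reusing the names above, not a recommended fix):

# Workaround sketch: fire the scheduler before each epoch rather than after it,
# so the first epoch already trains with warmup_start_value.
trainer.add_event_handler(Events.EPOCH_STARTED, with_warmup_scheduler)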
@vfdev-5 (Collaborator) commented Jan 24, 2022

@sadra-barikbin have you seen the example from the docs: https://pytorch.org/ignite/generated/ignite.handlers.param_scheduler.create_lr_scheduler_with_warmup.html#ignite.handlers.param_scheduler.create_lr_scheduler_with_warmup

I'm not sure if it makes sense to perform warm-up on epochs.

But anyway, if you check our docs, almost all handlers in ignite.handlers.param_scheduler are attached to ITERATION_STARTED or EPOCH_STARTED.

@sadra-barikbin (Collaborator, Author)

Um, you're right! People do warm-up on iterations. The question that arises is: how does one do warm-up on iterations and then normal scheduling on epochs using create_lr_scheduler_with_warmup's scheduler? At the moment, it listens only to epoch events or only to iteration events.

@vfdev-5 (Collaborator) commented Jan 25, 2022

@sadra-barikbin you can simply run your post-warm-up epoch-wise scheduling on iterations, if possible (a sketch of this follows the output below).
Otherwise, you can try to combine events as below:

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR

from ignite.engine import Engine, Events
from ignite.handlers import create_lr_scheduler_with_warmup


def train_step(e, b):
    print(trainer.state.epoch, trainer.state.iteration, " | ", optimizer.param_groups[0]["lr"])

    
trainer = Engine(train_step)
optimizer = optim.SGD([torch.tensor([0.1])], lr=0.1234)


torch_lr_scheduler = ExponentialLR(optimizer=optimizer, gamma=0.5)

data = [0] * 8
epoch_length = len(data)
warmup_duration = 5
scheduler = create_lr_scheduler_with_warmup(torch_lr_scheduler,
                                            warmup_start_value=0.0,
                                            warmup_duration=warmup_duration)

# Trigger scheduler on iteration_started events before reaching warmup_duration
combined_events = Events.ITERATION_STARTED(event_filter=lambda _, __: trainer.state.iteration <= warmup_duration)
# Trigger scheduler on epoch_started events after the warm-up. Epochs are 1-based, hence the 1 + offset below
combined_events |= Events.EPOCH_STARTED(event_filter=lambda _, __: trainer.state.epoch > 1 + warmup_duration / epoch_length)
trainer.add_event_handler(combined_events, scheduler)
   
trainer.run(data, max_epochs=10)
Output:
1 1  |  0.0
1 2  |  0.03085
1 3  |  0.0617
1 4  |  0.09255
1 5  |  0.1234
1 6  |  0.1234
1 7  |  0.1234
1 8  |  0.1234
2 9  |  0.0617
2 10  |  0.0617
2 11  |  0.0617
2 12  |  0.0617
2 13  |  0.0617
2 14  |  0.0617
2 15  |  0.0617
2 16  |  0.0617
3 17  |  0.03085
3 18  |  0.03085
3 19  |  0.03085
3 20  |  0.03085
3 21  |  0.03085
3 22  |  0.03085
3 23  |  0.03085
3 24  |  0.03085
4 25  |  0.015425
4 26  |  0.015425
4 27  |  0.015425
4 28  |  0.015425
4 29  |  0.015425
4 30  |  0.015425
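
For completeness, the first option above (running the post-warm-up scheduler on iterations rather than on epochs) could look like the following sketch, reusing optimizer, trainer, epoch_length, and warmup_duration from the snippet above. The StepLR with a rescaled step size is an assumption, standing in for any epoch-wise scheduler whose period can be expressed in iterations:

from torch.optim.lr_scheduler import StepLR

# Hypothetical sketch: a scheduler meant to step once per epoch instead steps
# every epoch_length iterations, so the whole schedule (warmup included) can
# be driven by a single ITERATION_STARTED handler.
iteration_scheduler = StepLR(optimizer, step_size=epoch_length, gamma=0.5)
scheduler_on_iterations = create_lr_scheduler_with_warmup(iteration_scheduler,
                                                          warmup_start_value=0.0,
                                                          warmup_duration=warmup_duration)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler_on_iterations)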

By the way, our docs on create_lr_scheduler_with_warmup are incorrect, cc @sdesrozis.
We should trigger the scheduler on ITERATION_STARTED instead of ITERATION_COMPLETED if we want to avoid the first iteration running with the optimizer's default LR (see the sketch below).
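
A minimal illustration of that docs fix, using the scheduler from the snippet above (the ITERATION_COMPLETED line reflects the old docs and is kept only for contrast):

# Old docs pattern: the first iteration still runs with the constructor's LR
# trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
# Fixed pattern: warmup_start_value is applied before the first iteration
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)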

@vfdev-5 (Collaborator) commented Jan 27, 2022

The docs issue is fixed in #2442.
Let's just add another example with mixed events, as in the code above: #2441 (comment)

@sadra-barikbin (Collaborator, Author)

Shall I add the example to the docs?

@vfdev-5 (Collaborator) commented Jan 29, 2022

Yes, this could be helpful. Thanks!

@trsvchn (Collaborator) commented Feb 15, 2022

Furthermore, as discussed, another option is to follow sklearn's practice: we can put this example into the how-to guides and cross-link it in the docstrings.

@sadra-barikbin (Collaborator, Author)

@trsvchn By "how-to-guides", do you mean a place like this?

If so, I'll add the example to that page and then add a reference to it on create_lr_scheduler_with_warmup's page. Is that right?

@trsvchn (Collaborator) commented Feb 21, 2022

@sadra-barikbin we meant our new how-to-guides page, here

@trsvchn (Collaborator) commented Feb 21, 2022

@sadra-barikbin so basically you add the example to these examples, it will be rendered on the main website (the new one), and we can reference it from there.

Please check the contributing guide for the examples, and don't hesitate to ask for help!
