#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
import argparse
from typing import Any, Dict, List

import torch.nn

from corenet.optims.base_optim import BaseOptim
from corenet.utils import logger
from corenet.utils.common_utils import unwrap_model_fn
from corenet.utils.registry import Registry

OPTIM_REGISTRY = Registry(
    registry_name="optimizer_registry",
    base_class=BaseOptim,
    lazy_load_dirs=["corenet/optims"],
    internal_dirs=["corenet/internal", "corenet/internal/projects/*"],
)
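
# Illustrative sketch (not part of the original file): concrete optimizers make
# themselves discoverable by registering against OPTIM_REGISTRY, after which
# 'build_optimizer' below can construct them by name via '--optim.name'. The
# decorator form mirrors how corenet registries are typically used; 'my_optim'
# and 'MyOptimizer' are hypothetical names, and the exact constructor signature
# is an assumption.
#
#   @OPTIM_REGISTRY.register(name="my_optim")
#   class MyOptimizer(BaseOptim):
#       def __init__(self, opts, model_params) -> None:
#           super().__init__(opts)
#           ...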


def check_trainable_parameters(
    model: torch.nn.Module, model_params: List[Dict[str, Any]]
) -> None:
    """Helper function to check whether any model parameters with gradients are missing from 'model_params'.

    'get_trainable_parameters' is a custom function, so there may be instances where not all trainable
    parameters are passed to the optimizer, potentially causing training instabilities and yielding
    undesired results. This function compares the named parameters obtained from 'get_trainable_parameters'
    with those from PyTorch's 'model.named_parameters()', which helps surface such issues before the
    training phase.

    Args:
        model: An instance of torch.nn.Module.
        model_params: Model parameters computed using the 'get_trainable_parameters' function.
    """
    # get model parameter names
    model_trainable_params = []

    # Activation checkpointing, enabled using --model.activation-checkpointing, adds the
    # prefix '_checkpoint_wrapped_module' to sub-module names.
    # If the prefix is present in a parameter name, remove it.
    act_ckpt_wrapped_module_name = "._checkpoint_wrapped_module"
    for p_name, param in model.named_parameters():
        if param.requires_grad:
            p_name = p_name.replace(act_ckpt_wrapped_module_name, "")
            model_trainable_params.append(p_name)

    initialized_params = []
    for param_info in model_params:
        if not isinstance(param_info, Dict):
            logger.error(
                "Expected format is a Dict with three keys: params, weight_decay, param_names"
            )
        if not {"params", "weight_decay", "param_names"}.issubset(param_info.keys()):
            logger.error(
                "Parameter dict should have three keys: params, weight_decay, param_names"
            )
        param_names = param_info["param_names"]
        if isinstance(param_names, List):
            param_names = [
                param_name.replace(act_ckpt_wrapped_module_name, "")
                for param_name in param_names
            ]
            initialized_params.extend(param_names)
        elif isinstance(param_names, str):
            param_names = param_names.replace(act_ckpt_wrapped_module_name, "")
            initialized_params.append(param_names)
        else:
            raise NotImplementedError

    uninitialized_params = set(model_trainable_params) ^ set(initialized_params)
    if len(uninitialized_params) > 0:
        logger.error(
            "The following parameters are defined in the model, but won't be part of the optimizer. "
            "Please check the get_trainable_parameters function. "
            "Use the --optim.bypass-parameters-check flag to bypass this check. "
            "Parameter list = {}".format(uninitialized_params)
        )
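
# A minimal sketch of the check above on toy parameter names (hypothetical
# values, shown for illustration only): the symmetric difference flags names
# that appear on one side but not the other.
#
#   model_trainable_params = ["encoder.weight", "encoder.bias", "head.weight"]
#   initialized_params = ["encoder.weight", "encoder.bias"]
#   set(model_trainable_params) ^ set(initialized_params)  # -> {"head.weight"}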


def remove_param_name_key(model_params: List) -> None:
    """Helper function to remove the 'param_names' key from 'model_params'.

    The optimizer only takes 'params' and 'weight_decay' as keys. However, 'get_trainable_parameters'
    returns three keys: (1) params, (2) weight_decay, and (3) param_names. The 'param_names' key is used
    for sanity checking in the 'check_trainable_parameters' function, and is removed inside this function
    so that 'model_params' can be passed to the optimizer.

    Args:
        model_params: A list of dictionaries, where each dictionary is expected to have
            three keys: (1) params: an instance of torch.nn.Parameter, (2) weight_decay, and (3) param_names.

    .. note::
        This function should be called after the 'check_trainable_parameters' function.
    """
    for param_info in model_params:
        if not isinstance(param_info, Dict):
            logger.error(
                "Expected format is a Dict with three keys: params, weight_decay, param_names"
            )
        if not {"params", "weight_decay", "param_names"}.issubset(param_info.keys()):
            logger.error(
                "Parameter dict should have three keys: params, weight_decay, param_names"
            )
        param_info.pop("param_names")
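
# For illustration, a 'model_params' entry before and after this function runs
# (hypothetical values): the 'param_names' bookkeeping key is dropped, leaving
# only keys that torch optimizers accept in a parameter group.
#
#   before: {"params": [w], "weight_decay": 0.01, "param_names": ["encoder.weight"]}
#   after:  {"params": [w], "weight_decay": 0.01}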


def build_optimizer(model: torch.nn.Module, opts, *args, **kwargs) -> BaseOptim:
    """Helper function to build an optimizer.

    Args:
        model: A model.
        opts: Command-line arguments.

    Returns:
        An instance of BaseOptim.
    """
    optim_name = getattr(opts, "optim.name")
    weight_decay = getattr(opts, "optim.weight_decay")
    no_decay_bn_filter_bias = getattr(opts, "optim.no_decay_bn_filter_bias")

    unwrapped_model = unwrap_model_fn(model)
    model_params, lr_mult = unwrapped_model.get_trainable_parameters(
        weight_decay=weight_decay,
        no_decay_bn_filter_bias=no_decay_bn_filter_bias,
        *args,
        **kwargs
    )
    # check to ensure that all trainable model parameters are passed to the optimizer
    if not getattr(opts, "optim.bypass_parameters_check", False):
        check_trainable_parameters(model=unwrapped_model, model_params=model_params)
    remove_param_name_key(model_params=model_params)

    # set the learning rate multiplier for each parameter
    setattr(opts, "optim.lr_multipliers", lr_mult)
    return OPTIM_REGISTRY[optim_name](opts, model_params, *args, **kwargs)
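
# Illustrative usage sketch (assumes 'opts' was produced by corenet's argument
# parsing and 'model' implements 'get_trainable_parameters'; 'loader' and
# 'compute_loss' are hypothetical placeholders, not defined in this file):
#
#   optimizer = build_optimizer(model, opts)
#   for batch in loader:
#       loss = compute_loss(model, batch)
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()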


def arguments_optimizer(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Adds arguments from BaseOptim and all registered optimizers to the given parser."""
    parser = BaseOptim.add_arguments(parser)
    parser = OPTIM_REGISTRY.all_arguments(parser)
    return parser
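
# A minimal sketch of wiring the optimizer arguments into a standalone parser
# (corenet's own entry points normally do this internally; the dotted flag
# format shown here is an assumption based on the dotted names read via
# getattr in 'build_optimizer'):
#
#   parser = argparse.ArgumentParser()
#   parser = arguments_optimizer(parser)
#   opts = parser.parse_args(["--optim.name", "sgd"])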