
The mlm loss computation in the function _get_batch_loss_bert seems wrong in d2l pytorch code #2582

Open
lyconghk opened this issue Jan 6, 2024 · 2 comments

lyconghk commented Jan 6, 2024

In my opinion, the BERT pretraining batch loss computed in the function _get_batch_loss_bert is not correct. Here are the details:

The CrossEntropyLoss is initialized with the default reduction 'mean':

loss = nn.CrossEntropyLoss()

In the function _get_batch_loss_bert, the MLM loss and the NSP loss are computed with this same loss instance:

mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)

Since reduction='mean', the result of loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) is a scalar tensor, so the positionwise product with mlm_weights_X does not weight the per-token losses: it merely broadcasts that single scalar across the weight tensor.
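For illustration, a minimal repro (hypothetical shapes, not the actual d2l tensors) showing that the mean-reduced loss is a 0-dim scalar, so the product just broadcasts it over the mask:

import torch
from torch import nn

loss = nn.CrossEntropyLoss()  # reduction='mean' by default
vocab_size = 10
mlm_Y_hat = torch.randn(4, vocab_size)  # 4 masked positions, already flattened
mlm_Y = torch.randint(0, vocab_size, (4,))
mlm_weights_X = torch.tensor([1., 1., 1., 0.])  # 0 marks a padded position

l = loss(mlm_Y_hat, mlm_Y)
print(l.shape)  # torch.Size([]) -- a 0-dim scalar, per-token losses already averaged away
mlm_l = l * mlm_weights_X.reshape(-1, 1)
print(mlm_l.shape)  # torch.Size([4, 1]) -- the scalar broadcast over the mask, not a weighted loss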

lyconghk changed the title The mlm loss computation in the function _get_batch_loss_bert is wrong in d2l pytorch code → The mlm loss computation in the function _get_batch_loss_bert seems wrong in d2l pytorch code Jan 6, 2024

gab-chen commented Jan 28, 2024

I agree with you, @lyconghk. Have you come up with a better solution for applying mlm_weights_X in the mlm_l calculation?

The weight parameter of PyTorch's CrossEntropyLoss does not seem to support mlm_weights_X the way the MXNet version does: it takes per-class weights, not per-token sample weights. I guess that is why the PyTorch version of _get_batch_loss_bert calculates mlm_l this way; it tries to reduce the impact of padded tokens on mlm_l, but it does not use mlm_weights_X in a correct way.
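To make the mismatch concrete (a sketch with made-up shapes): the weight argument of CrossEntropyLoss is indexed by target class, i.e. it has shape (vocab_size,), whereas mlm_weights_X is a per-position mask, so it can only be applied manually to an unreduced loss:

import torch
from torch import nn

vocab_size = 10
logits = torch.randn(6, vocab_size)
targets = torch.randint(0, vocab_size, (6,))

# PyTorch `weight`: one entry per class, looked up by target id
per_class = nn.CrossEntropyLoss(weight=torch.ones(vocab_size))
scalar_l = per_class(logits, targets)  # still reduced to a single scalar

# A per-token mask like mlm_weights_X has one entry per position instead,
# so it must be applied elementwise to a reduction='none' loss:
token_mask = torch.tensor([1., 1., 0., 1., 0., 0.])  # hypothetical mask
unreduced = nn.CrossEntropyLoss(reduction='none')(logits, targets)  # shape (6,)
masked = unreduced * token_mask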


lyconghk commented Jan 29, 2024

How about just using torch.nn.functional to compute the two cross-entropy losses for MLM and NSP, and removing the loss input parameter from the function _get_batch_loss_bert?

from torch.nn import functional as F

mlm_l = F.cross_entropy(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1), reduction='none')
nsp_l = F.cross_entropy(nsp_Y_hat, nsp_Y, reduction='mean')
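With reduction='none', mlm_l comes back with one entry per prediction position, so the mask can then be applied elementwise and normalized by the number of non-padded predictions, mirroring the normalization _get_batch_loss_bert already does:

mlm_l = mlm_l * mlm_weights_X.reshape(-1)
mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)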
