
weighted sum in masked loss #416

Open
nehSgnaiL opened this issue May 14, 2024 · 1 comment
Hi,

Thanks for the remarkable work.

I would like to better understand the operations in the loss function. Since the mask has been normalized by mask /= torch.mean(mask), should the returned loss use the sum operation torch.sum(loss) rather than the mean operation torch.mean(loss)?

import numpy as np
import torch

def masked_mse_torch(preds, labels, null_val=np.nan):
    # Snap near-zero labels to exact zero.
    labels[torch.abs(labels) < 1e-4] = 0
    # Mark valid entries: non-NaN, or not equal to the given null value.
    if np.isnan(null_val):
        mask = ~torch.isnan(labels)
    else:
        mask = labels.ne(null_val)
    mask = mask.float()
    # Rescale so the mask averages to 1: each valid entry gets weight
    # (total elements / valid elements).
    mask /= torch.mean(mask)
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    loss = torch.square(torch.sub(preds, labels))
    loss = loss * mask
    # Zero out NaNs produced by invalid entries before averaging.
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.mean(loss)
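For what it's worth, the rescaled-mask mean and the plain masked mean agree algebraically: mean(loss * mask / mean(mask)) = sum(valid losses) / count(valid). A minimal NumPy sketch (toy values assumed, mirroring the torch code above) checks this numerically:

```python
import numpy as np

# Toy data: one missing label marked with NaN.
labels = np.array([1.0, 2.0, np.nan, 4.0])
preds = np.array([1.5, 2.0, 3.0, 3.0])

mask = (~np.isnan(labels)).astype(float)  # [1, 1, 0, 1]
mask = mask / mask.mean()                 # mean is 0.75 -> valid weights 4/3

sq_err = (preds - labels) ** 2            # NaN where the label is missing
loss = sq_err * mask                      # NaN * 0 is still NaN
loss = np.where(np.isnan(loss), 0.0, loss)

# mean over ALL entries, with the rescaled mask baked in
via_rescaled_mean = loss.mean()

# plain masked mean: squared errors over valid entries only
valid = ~np.isnan(labels)
via_masked_mean = ((preds - labels) ** 2)[valid].mean()

print(via_rescaled_mean, via_masked_mean)  # both 0.41666... (= 1.25 / 3)
```

So torch.mean(loss) with the normalized mask already divides by the number of valid entries; switching to torch.sum(loss) would multiply the result by the tensor size.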

I am not sure whether my understanding is correct. If you could respond, that would be greatly appreciated.

nehSgnaiL changed the title from "weighted sum is masked loss" to "weighted sum in masked loss" on May 14, 2024
@nehSgnaiL (Author)
The code mask /= torch.mean(mask) followed by return torch.mean(loss) works properly. The outcome is correct, but I find it difficult to read.
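If readability is the concern, the same quantity can be written as a direct ratio of sums. This is a hedged sketch, not the repository's code; the NumPy port and the function name masked_mse_explicit are illustrative assumptions:

```python
import numpy as np

def masked_mse_explicit(preds, labels):
    # Illustrative NumPy port: NaN in labels marks a missing entry.
    valid = ~np.isnan(labels)
    sq_err = np.square(preds[valid] - labels[valid])
    # sum of squared errors over valid entries / number of valid entries
    return sq_err.sum() / valid.sum()

labels = np.array([1.0, 2.0, np.nan, 4.0])
preds = np.array([1.5, 2.0, 3.0, 3.0])
print(masked_mse_explicit(preds, labels))  # 1.25 / 3
```

This form makes the "divide by the number of valid entries" step explicit instead of hiding it inside the mask normalization.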
