Metric with multiple input runs in an unexpected way. #2940

Open

lyhyl opened this issue May 9, 2023 · 1 comment

Comments
lyhyl commented May 9, 2023

❓ Questions/Help/Support

My custom loss requires two pairs of inputs:

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyLoss(nn.Module):
    def __init__(self, ca: float = 1.0, cb: float = 1.0) -> None:
        super().__init__()
        self.ca = ca
        self.cb = cb

    def forward(self, y_pred: Tuple[torch.Tensor, torch.Tensor], y_true: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        a_true, b_true = y_true
        a_pred, b_pred = y_pred
        return self.ca * F.mse_loss(a_pred, a_true) + self.cb * F.cross_entropy(b_pred, b_true)
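
Called directly, the loss works as intended. Here is a minimal check; the tensor shapes and the class-index target format are assumptions for illustration:

loss = MyLoss(0.5, 1.0)
y_pred = (torch.rand(4, 1), torch.rand(4, 2))           # (regression output, class logits)
y_true = (torch.rand(4, 1), torch.randint(0, 2, (4,)))  # (regression target, class indices)
print(loss(y_pred, y_true))  # a single scalar tensor combining both terms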

When I try to log the loss with the Loss metric:

from ignite.engine import create_supervised_evaluator
from ignite.metrics import Loss

loss = MyLoss(0.5, 1.0)
metrics = {
    "Loss": Loss(loss)
}
train_evaluator = create_supervised_evaluator(model, metrics, device, prepare_batch=prepare_batch)

it crashes on the line:

self.update((tensor_o1, tensor_o2))

because the metric treats the inputs as independent (y_pred, y) pairs, which is not what MyLoss needs.
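
For context, the unrolling behaves roughly as in the sketch below; this is a simplified illustration, not the exact ignite source. Because both elements of the engine output are tuples of tensors, the metric zips them and calls update once per tensor pair:

output = ((a_pred, b_pred), (a_true, b_true))  # meant as one (y_pred, y_true) pair for MyLoss
# Both elements are sequences of tensors, so the metric unrolls them:
for tensor_o1, tensor_o2 in zip(output[0], output[1]):
    self.update((tensor_o1, tensor_o2))  # called with (a_pred, a_true), then (b_pred, b_true)
# Loss then hands MyLoss a single tensor where it expects a pair,
# and the tuple unpacking in MyLoss.forward fails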

Digging into the source code, I found that #2055 introduced a new feature, which causes this issue.
So, what is the best practice for dealing with losses that take multiple inputs?

lyhyl added the question label on May 9, 2023
lyhyl changed the title from "Loss metric with multiple input runs in an unexpected way." to "Metric with multiple input runs in an unexpected way." on May 9, 2023
vfdev-5 (Collaborator) commented May 9, 2023

@lyhyl thanks for reporting this issue!

Right now, a workaround could be to replace the Tuple[torch.Tensor, torch.Tensor] structure with something non-iterable, to prevent the unrolling by Metric.

import torch
import torch.nn as nn
import torch.nn.functional as F

from ignite.engine import create_supervised_evaluator
from ignite.metrics import Loss


class TargetsPair:
    # A plain object instead of a tuple: it is not a sequence of tensors,
    # so Metric will not unroll it into separate (y_pred, y) pairs.
    a: torch.Tensor
    b: torch.Tensor

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __len__(self):
        # Loss uses len(y) as the default batch size
        return len(self.a)


class MyLoss(nn.Module):
    def __init__(self, ca: float = 1.0, cb: float = 1.0) -> None:
        super().__init__()
        self.ca = ca
        self.cb = cb

    def forward(self, y_pred: TargetsPair, y_true: TargetsPair) -> torch.Tensor:
        a_true, b_true = y_true.a, y_true.b
        a_pred, b_pred = y_pred.a, y_pred.b
        return self.ca * F.mse_loss(a_pred, a_true) + self.cb * F.cross_entropy(b_pred, b_true)


def prepare_batch(batch, device, non_blocking):
    # Dummy batch: a random input and an (a_true, b_true) target pair
    return torch.rand(4, 1), (torch.rand(4, 1), torch.rand(4, 2))


class MyModel(nn.Module):

    def forward(self, x):
        # Dummy model: random (a_pred, b_pred) outputs of fixed shapes
        return torch.rand(4, 1), torch.rand(4, 2)


model = MyModel()


def output_transform(output):
    # Repack the (y_pred, y) tuples into non-iterable TargetsPair objects
    (a_pred, b_pred), (a_true, b_true) = output
    return TargetsPair(a_pred, b_pred), TargetsPair(a_true, b_true)


device = "cpu"
loss = MyLoss(0.5, 1.0)
metrics = {
    "Loss": Loss(loss, output_transform=output_transform)
}
train_evaluator = create_supervised_evaluator(model, metrics, device, prepare_batch=prepare_batch)


data = range(10)
train_evaluator.run(data)
print(train_evaluator.state.metrics["Loss"])  # averaged loss over the 10 iterations

In the future, we may introduce a flag in the Metric class to skip the output unrolling and feed the whole output into the update function.
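
For illustration only, such a flag might be used as below; skip_unrolling is a name assumed for this sketch and did not exist when this comment was written (an argument of that name was added to Metric in later ignite releases):

metrics = {
    # hypothetical flag: hand the whole output to update() without unrolling
    "Loss": Loss(loss, skip_unrolling=True)
}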
