
[TODO]: Metric Improvements #817

Open
KnathanM opened this issue Apr 19, 2024 · 9 comments

Comments

@KnathanM
Contributor

I think we can decouple Metric from LossFunction. The forward method is already different because, unlike LossFunction's, it doesn't use task weights to reduce the loss. Metric should also have its own __init__ because metrics don't take task weights. At that point there isn't anything that Metric needs to get from LossFunction.

We currently couple them because some metrics inherit _calc_unreduced_loss from their loss function counterparts, like BinaryMCCMetric. At the same time, though, I don't know whether these metrics are actually calculated correctly. It looks like _calc_unreduced_loss in BinaryMCCLoss already returns a reduced value, in which case that method should really be forward().

If that is changed to forward(), there is also the question of whether mask needs to remain in the signature of _calc_unreduced_loss, since the forward method of LossFunction handles masking as part of the reduction.

All of this is to say that I think the more advanced metrics need some improvements, but that will take more time to discuss, so it isn't part of v2.0.
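
For what it's worth, a decoupled Metric base class might look roughly like the following. This is only a sketch with illustrative names (Metric, _calc_unreduced, MAEMetric), not the actual chemprop classes: it gets its own __init__ with no task weights and a forward that only applies the mask.

import torch
from torch import Tensor, nn


class Metric(nn.Module):
    """Sketch of a standalone Metric: no task weights, no ties to LossFunction."""

    def forward(self, preds: Tensor, targets: Tensor, mask: Tensor) -> Tensor:
        # reduce over only the entries selected by the mask
        return self._calc_unreduced(preds, targets)[mask].mean()

    def _calc_unreduced(self, preds: Tensor, targets: Tensor) -> Tensor:
        raise NotImplementedError


class MAEMetric(Metric):
    def _calc_unreduced(self, preds: Tensor, targets: Tensor) -> Tensor:
        return (preds - targets).abs()


# usage with toy tensors
preds = torch.randn(4, 2)
targets = torch.randn(4, 2)
mask = torch.ones(4, 2, dtype=torch.bool)
print(MAEMetric()(preds, targets, mask))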

@KnathanM KnathanM added the todo add an item to the to-do list label Apr 19, 2024
@KnathanM KnathanM added this to the v2.1.0 milestone Apr 19, 2024
@davidegraff
Contributor

FWIW the original goal was to make it such that a LossFunction is a valid Metric but not all Metrics are valid LossFunctions. I don't think I really achieved that, but the motivation is to avoid duplicating code, i.e., we shouldn't need to define fully separate MSELoss and MSEMetric classes because they're the same thing. Of course, this comes with the caveat that, as originally defined, validation metrics were unscaled by sample weights and task weights. I don't know if I agree with this; it would seem to make more sense to allow weights to be assigned to the respective tasks (while keeping the default of evenly weighted tasks). Given the original implementation, I needed some workarounds to make Metrics not use the sample and task weights.

In summary:

  • It would be nice if users could simply supply an arbitrary LossFunction as input to the metrics argument of MPNN
  • this would require flipping the inheritance scheme, i.e., Metric < LossFunction (read: "LossFunction inherits from Metric")
    • to me this makes sense. If a Metric is a class of function $f : \mathbb R \times \mathbb R \to \mathbb R$ that is possibly non-differentiable, then a LossFunction is a specific subclass of function $g$ that is differentiable. The reverse is not true

I have a few ideas on this, but I'm curious to hear what others think.
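
For concreteness, a minimal sketch of what that flipped hierarchy could look like (hypothetical classes, not the current chemprop API):

import torch
from torch import Tensor, nn


class Metric(nn.Module):
    """Any scalar evaluation function; differentiability is not required."""

    def forward(self, preds: Tensor, targets: Tensor) -> Tensor:
        raise NotImplementedError


class LossFunction(Metric):
    """A Metric that is additionally differentiable and applies task weights."""

    def __init__(self, task_weights: Tensor | None = None):
        super().__init__()
        self.task_weights = task_weights


class MSE(LossFunction):
    def forward(self, preds: Tensor, targets: Tensor) -> Tensor:
        sq_err = (preds - targets) ** 2
        if self.task_weights is not None:
            sq_err = sq_err * self.task_weights
        return sq_err.mean()

Under this scheme, any LossFunction instance could be passed wherever a Metric is expected.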

@shihchengli
Contributor

I'd prefer to decouple Metric from LossFunction, as they serve different purposes, though I can see that we want minimal code duplication between the two. Flipping the inheritance so that LossFunction inherits from Metric would avoid duplication, but it implies that all loss functions are metrics, even though not all metrics are suitable as loss functions (e.g., MVELoss). We could use mixins to get code reuse while avoiding that inheritance.
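
As an illustration of the mixin idea (hypothetical classes, not a concrete API proposal), the shared elementwise computation could live in a mixin that both the loss and the metric reuse:

from torch import Tensor, nn


class SquaredErrorMixin:
    """Shared elementwise computation, reused by both the loss and the metric."""

    def _unreduced(self, preds: Tensor, targets: Tensor) -> Tensor:
        return (preds - targets) ** 2


class MSELoss(SquaredErrorMixin, nn.Module):
    def forward(self, preds: Tensor, targets: Tensor, task_weights: Tensor) -> Tensor:
        # the loss applies task weights before reducing
        return (self._unreduced(preds, targets) * task_weights).mean()


class MSEMetric(SquaredErrorMixin, nn.Module):
    def forward(self, preds: Tensor, targets: Tensor) -> Tensor:
        # the metric reduces without any weighting
        return self._unreduced(preds, targets).mean()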

@davidegraff
Contributor

> I'd prefer to decouple Metric from LossFunction, as they serve different purposes, though I can see that we want minimal code duplication between the two. Flipping the inheritance so that LossFunction inherits from Metric would avoid duplication, but it implies that all loss functions are metrics, even though not all metrics are suitable as loss functions (e.g., MVELoss). We could use mixins to get code reuse while avoiding that inheritance.

A couple of points:

  • If you were to train any vanilla PyTorch regression model and evaluate its performance on a validation set, what class/function would you use to monitor early stopping? The answer is definitively nn.MSELoss, and the same generally goes for classification models as well (when you don't know the important test-time metric). There is very clearly a thematic link between evaluation metrics and loss functions, and the code should reflect that.
  • No, not all metrics are suitable as loss functions, which is what I mentioned in my post above. A loss function, by definition, must be differentiable to be used in back-propagation/gradient descent; a metric need not be, though it can be. That is, to measure regression performance, you can use MSE, a correlation coefficient, etc. Some of these will be differentiable, but not all.
  • I'm generally not a fan of mixins. I used them originally, but they're a clunky way to avoid code duplication, and they muddle the inheritance structure when used with multiple inheritance: it's not clear which superclass defines the original method, and when both of them do, a reader has to know the proper MRO (see the sketch below).
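
A tiny illustration of that MRO point (purely hypothetical mixin names):

class BoundedMixin:
    def transform(self, preds, targets):
        # hypothetical: clip predictions to the bounded targets
        return "bounded", targets


class ScaledMixin:
    def transform(self, preds, targets):
        # hypothetical: rescale the predictions
        return "scaled", targets


class BoundedScaledMetric(BoundedMixin, ScaledMixin):
    pass


# Both mixins define transform(); only the MRO tells you which one runs.
print([cls.__name__ for cls in BoundedScaledMetric.__mro__])
# ['BoundedScaledMetric', 'BoundedMixin', 'ScaledMixin', 'object']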

@shihchengli
Contributor

shihchengli commented Apr 20, 2024

I agree that there is a strong correlation between evaluation metrics and loss functions. I originally wanted to say that not all loss functions are suitable as metrics (e.g., MVELoss, EvidentialLoss). Therefore, I oppose allowing users to "supply an arbitrary LossFunction as input to the metrics argument of MPNN". Yes, it also seems the inheritance structure would be quite complex for some metrics, like BoundedMAEMetric, which already includes one mixin.

@davidegraff
Contributor

> I originally wanted to say that not all loss functions are suitable as metrics (e.g., MVELoss, EvidentialLoss). Therefore, I oppose allowing users to "supply an arbitrary LossFunction as input to the metrics argument of MPNN".

I don't think this statement is correct. For example, if you were training an MVE model with early stopping, you wouldn't want to monitor just RMSE because it doesn't capture whether your model is starting to overfit. That is, your model could optimize towards the trivial solution ($\hat \sigma^2 \to 0$), but this would not be the loss-minimizing solution because the metric you're monitoring for early stopping is not equivalent to the objective you're using to train your model.
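
A rough numerical illustration of this (assuming a standard Gaussian NLL for MVE, not chemprop's exact implementation): with the mean predictions held fixed, RMSE is completely blind to the predicted variance, while the NLL diverges as $\hat \sigma^2 \to 0$ whenever the residuals are nonzero.

import torch

targets = torch.tensor([1.0, 2.0, 3.0])
mu = torch.tensor([1.1, 1.9, 3.2])  # fixed mean predictions

rmse = ((mu - targets) ** 2).mean().sqrt()  # unchanged no matter what the variance head does

for var in (1.0, 0.02, 1e-4):
    sigma2 = torch.full_like(mu, var)
    # Gaussian negative log-likelihood (up to an additive constant)
    nll = 0.5 * (sigma2.log() + (targets - mu) ** 2 / sigma2).mean()
    print(f"var={var:g}  RMSE={rmse.item():.3f}  NLL={nll.item():.3f}")

So a model whose variance head collapses looks unchanged under validation RMSE but is heavily penalized by the validation NLL, which is why the monitored metric should match the training objective here.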

@davidegraff
Contributor

davidegraff commented Apr 24, 2024

I've been thinking about my original comment some more, mainly this point:

>   • It would be nice if users could simply supply an arbitrary LossFunction as input to the metrics argument of MPNN
>   • this would require flipping the inheritance scheme, i.e., Metric < LossFunction (read: "LossFunction inherits from Metric")
>     • to me this makes sense. If a Metric is a class of function $f : \mathbb R \times \mathbb R \to \mathbb R$ that is possibly non-differentiable, then a LossFunction is a specific subclass of function $g$ that is differentiable. The reverse is not true

And I disagree with it now. Differentiable functions are indeed a subset of (possibly non-differentiable) functions, but in code this would reverse the inheritance structure. That is, if we have some class of functions $f$ that is a subset of the class of functions $g$, then $f < g$, and the $f < g$ relation is expressed in code via class G extends F, or in Python via class G(F). So I'm back to my original design decision: the base class should be LossFunction, and we have a subclass Metric (or MetricFunction) that inherits from it:

class LossFunction(nn.Module):
    ...

class MetricFunction(LossFunction):
    ...

To be clear, a user can use a non-differentiable loss function, but the gradients won't propagate back from the final loss value (i.e., they'll stop right at the non-differentiable step). But we should decide whether we even need a subclass MetricFunction or whether we can just call everything a LossFunction, with some of them being non-differentiable.
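
A small illustration of that last point with toy tensors (not chemprop code): once a hard, non-differentiable step enters the computation, the autograd graph ends there.

import torch

logits = torch.randn(8, requires_grad=True)
preds = torch.sigmoid(logits)
targets = torch.randint(0, 2, (8,)).float()

# differentiable objective: gradients can flow back to the logits
bce = torch.nn.functional.binary_cross_entropy(preds, targets)
print(bce.requires_grad)  # True

# non-differentiable "metric": the hard threshold cuts the graph
hard = (preds > 0.5).float()       # the comparison has no gradient history
acc = (hard == targets).float().mean()
print(acc.requires_grad)           # False; calling backward() from here would raise an error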

@KnathanM
Contributor Author

KnathanM commented May 9, 2024

Side note: also see #612 for a new loss function that could be added as we improve the metrics and loss functions.

@SteshinSS

Just my 2 cents: At this moment, Chemprop v2 MPNN uses self.criterion to calculate train_loss and the first metric in self.metrics to calculate val_loss. This is either a bug or a very confusing decision.
https://github.com/chemprop/chemprop/blob/main/chemprop/models/model.py#L161

@KnathanM
Contributor Author

Hi Simon, thanks for the feedback. This isn't a bug, but we'd like to make it clearer in our documentation why we implemented it that way. Maybe we should call it val_metric instead of val_loss. Or do you think adding something like the following to our documentation would make it a less confusing decision?

> During training, a differentiable function (loss function/criterion) is used to calculate the deviation of the current model's predictions from the true values in the training batch. The gradient of this loss is used to update the model weights. Often a different, possibly non-differentiable, function will be used to evaluate the final model's performance. This test metric should also be used to calculate the performance of the model on the validation dataset so that the best model weights are saved during training. Chemprop does this automatically in the CLI. Refer to the lightning documentation for early stopping and checkpointing to see how to do this in a notebook.
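
In a notebook, that might look roughly like this (a sketch only; the monitored key "val_loss" and the trainer arguments are assumptions and should match whatever your model actually logs):

from lightning import pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

# monitor the validation metric so the best weights are checkpointed
callbacks = [
    EarlyStopping(monitor="val_loss", mode="min", patience=10),
    ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1),
]

trainer = pl.Trainer(max_epochs=100, callbacks=callbacks)
# trainer.fit(mpnn, train_dataloader, val_dataloader)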
