-
-
Notifications
You must be signed in to change notification settings - Fork 25k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
FEA D2 Brier Score #28971
base: main
Are you sure you want to change the base?
FEA D2 Brier Score #28971
Conversation
}, | ||
prefer_skip_nested_validation=True, | ||
) | ||
def d2_brier_score(y_true, y_proba, *, sample_weight=None, pos_label=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def d2_brier_score(y_true, y_proba, *, sample_weight=None, pos_label=None): | |
def d2_brier_score(y_true, y_proba, *, sample_weight=None, labels=None): |
This should be able to handle multiclass, too. See log_loss
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How are we supposed to extract the pos_label to pass to brier_score_loss?
y_true = column_or_1d(y_true) | ||
positive_label = _get_positive_label_for_brier_score(y_true, pos_label) | ||
weights = _check_sample_weight(sample_weight, y_true) | ||
positive_prob = np.sum((y_true == positive_label) * weights) / np.sum(weights) | ||
y_proba_ref = np.full(y_true.shape, positive_prob) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be the same code as in d2_log_loss.
def _get_positive_label_for_brier_score(y_true, pos_label=None): | ||
try: | ||
pos_label = _check_pos_label_consistency(pos_label, y_true) | ||
except ValueError: | ||
classes = np.unique(y_true) | ||
if classes.dtype.kind not in ("O", "U", "S"): | ||
# for backward compatibility, if classes are not string then | ||
# `pos_label` will correspond to the greater label | ||
pos_label = classes[-1] | ||
else: | ||
raise | ||
|
||
return pos_label |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def _get_positive_label_for_brier_score(y_true, pos_label=None): | |
try: | |
pos_label = _check_pos_label_consistency(pos_label, y_true) | |
except ValueError: | |
classes = np.unique(y_true) | |
if classes.dtype.kind not in ("O", "U", "S"): | |
# for backward compatibility, if classes are not string then | |
# `pos_label` will correspond to the greater label | |
pos_label = classes[-1] | |
else: | |
raise | |
return pos_label |
#22046 seems like a blocker |
Reference Issues/PRs
Closes #20943
What does this implement/fix? Explain your changes.
Any other comments?
CC: @lorentzenchr