-
Notifications
You must be signed in to change notification settings - Fork 10
/
loss.py
45 lines (35 loc) · 1.56 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from pydgn.training.callback.metric import Metric
class CGMMLoss(Metric):
    """Loss metric for CGMM: records the likelihood produced by the model
    and tracks its epoch-to-epoch value.

    CGMM is trained with EM-style updates rather than gradient descent, so
    ``on_backward`` intentionally performs no backward pass.
    """

    @property
    def name(self) -> str:
        """Key under which this metric appears in ``state.batch_loss`` /
        ``state.epoch_loss``."""
        return 'CGMM Loss'

    def __init__(self, use_as_loss=True, reduction='mean', use_nodes_batch_size=True):
        super().__init__(use_as_loss=use_as_loss, reduction=reduction,
                         use_nodes_batch_size=use_nodes_batch_size)
        # Likelihood of the previous epoch; starts at -inf so the first
        # epoch is always treated as an improvement.
        self.old_likelihood = -float('inf')
        # NOTE(review): assigned but never read anywhere in this file —
        # confirm whether external callers rely on it before removing.
        self.new_likelihood = None

    def _accumulate_batch(self, state):
        # Shared bookkeeping for training and evaluation batches: record the
        # scalar batch loss and count samples — graphs when the model does
        # graph classification, nodes otherwise (unsupervised CGMM).
        self.batch_metrics.append(state.batch_loss[self.name].item())
        if state.model.is_graph_classification:
            self.num_samples += state.batch_num_targets
        else:
            # This works for unsupervised CGMM
            self.num_samples += state.batch_num_nodes

    def on_training_batch_end(self, state):
        self._accumulate_batch(state)

    def on_eval_batch_end(self, state):
        self._accumulate_batch(state)

    def on_training_epoch_end(self, state):
        super().on_training_epoch_end(state)
        # NOTE(review): the original code compared the new epoch likelihood
        # against the previous one in a branch that did nothing, next to
        # disabled early stopping commented out with a typo
        # (``tate.stop_training = True``). The dead branch is removed; to
        # re-enable early stopping when the likelihood decreases, set
        # ``state.stop_training = True`` under that condition.
        self.old_likelihood = state.epoch_loss[self.name].item()

    # Simply ignore targets
    def forward(self, targets, *outputs):
        # The model is expected to place the likelihood at outputs[2];
        # targets are not used.
        likelihood = outputs[2]
        return likelihood

    def on_backward(self, state):
        # EM-based training: no gradient step is taken for this loss.
        pass