-
Notifications
You must be signed in to change notification settings - Fork 10
/
probabilistic_readout.py
110 lines (78 loc) · 3.73 KB
/
probabilistic_readout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from typing import Tuple, Optional, List
import torch
from pydgn.experiment.util import s2c
class ProbabilisticReadout(torch.nn.Module):
    """Abstract base class for probabilistic readouts trained with EM.

    Stores the problem dimensions and declares the E-step / M-step
    interface that concrete readouts must implement.
    """

    def __init__(self, dim_node_features, dim_edge_features, dim_target, config):
        """Record input/target dimensions.

        :param dim_node_features: node feature dimension (stored as K)
        :param dim_edge_features: edge feature dimension (stored as E)
        :param dim_target: target dimension (stored as Y)
        :param config: configuration dict (unused here; consumed by subclasses)
        """
        super().__init__()
        self.E = dim_edge_features
        self.K = dim_node_features
        self.Y = dim_target
        # small constant available to subclasses for numerical stability
        self.eps = 1e-8

    def init_accumulators(self):
        """Reset the sufficient-statistics accumulators."""
        raise NotImplementedError

    def e_step(self, p_Q, x_labels, y_labels, batch):
        """Perform the E-step given hidden-state posteriors and labels."""
        raise NotImplementedError

    def infer(self, p_Q, x_labels, batch):
        """Produce predictions from hidden-state posteriors."""
        raise NotImplementedError

    def complete_log_likelihood(self, posterior, emission_target, batch):
        """Compute the expected complete-data log-likelihood."""
        raise NotImplementedError

    def _m_step(self, x_labels, y_labels, posterior, batch):
        """Accumulate sufficient statistics for the M-step."""
        raise NotImplementedError

    def m_step(self):
        """Update parameters from the accumulated statistics."""
        raise NotImplementedError
class ProbabilisticNodeReadout(ProbabilisticReadout):
    """Supervised probabilistic readout over node states.

    Builds an emission distribution (class path taken from
    config['emission']) mapping target labels (dim Y) to CN node states,
    and delegates the EM machinery to it.
    """

    def __init__(self, dim_node_features, dim_edge_features, dim_target, config):
        super().__init__(dim_node_features, dim_edge_features, dim_target, config)
        self.emission_class = s2c(config['emission'])
        self.CN = config['C']  # number of states of a generic node
        # emission over the target dimension Y with CN latent node states
        self.emission = self.emission_class(self.Y, self.CN)

    def init_accumulators(self):
        """Reset the emission's sufficient-statistics accumulators."""
        self.emission.init_accumulators()

    def e_step(self, p_Q, x_labels, y_labels, batch):
        """E-step: return (true log-likelihood, readout posterior, emission target)."""
        emission_target = self.emission.e_step(x_labels, y_labels)  # ?n x CN
        readout_posterior = emission_target
        # true log P(y) using the observables
        # Mean of individual node terms
        p_x = (p_Q * readout_posterior).sum(dim=1)
        # mask zero-probability entries so log() yields 0 rather than -inf
        p_x[p_x == 0.] = 1.
        true_log_likelihood = p_x.log().sum(dim=0)
        return true_log_likelihood, readout_posterior, emission_target

    def infer(self, p_Q, x_labels, batch):
        """Delegate inference to the emission distribution."""
        return self.emission.infer(p_Q, x_labels)

    def complete_log_likelihood(self, eui, emission_target, batch):
        """Expected complete-data log-likelihood under the posterior eui.

        BUGFIX: zero entries of emission_target are masked to 1 before the
        log (the same guard e_step applies to p_x); otherwise a zero entry
        produces 0 * log(0) = 0 * -inf = NaN and poisons the sum. The mask
        is applied to a clone so the caller's tensor is not mutated.
        """
        safe_target = emission_target.clone()
        safe_target[safe_target == 0.] = 1.
        complete_log_likelihood = (eui * safe_target.log()).sum(1).sum()
        return complete_log_likelihood

    def _m_step(self, x_labels, y_labels, eui, batch):
        """Accumulate sufficient statistics in the emission distribution."""
        self.emission._m_step(x_labels, y_labels, eui)

    def m_step(self):
        """Update emission parameters, then reset the accumulators."""
        self.emission.m_step()
        self.init_accumulators()
class UnsupervisedProbabilisticNodeReadout(ProbabilisticReadout):
    """Unsupervised probabilistic readout over node states.

    Identical to the supervised variant except that the emission is built
    over the node-feature dimension K (not the target dimension Y) and the
    node features themselves stand in for the targets in every step.
    """

    def __init__(self, dim_node_features, dim_edge_features, dim_target, config):
        super().__init__(dim_node_features, dim_edge_features, dim_target, config)
        self.emission_class = s2c(config['emission'])
        self.CN = config['C']  # number of states of a generic node
        # unsupervised: emission over the node-feature dimension K
        self.emission = self.emission_class(self.K, self.CN)

    def init_accumulators(self):
        """Reset the emission's sufficient-statistics accumulators."""
        self.emission.init_accumulators()

    def e_step(self, p_Q, x_labels, y_labels, batch):
        """E-step: return (true log-likelihood, readout posterior, emission target)."""
        # Pass x_labels as y_labels (no targets in the unsupervised setting)
        emission_target = self.emission.e_step(x_labels, x_labels)  # ?n x CN
        readout_posterior = emission_target
        # true log P(y) using the observables
        # Mean of individual node terms
        p_x = (p_Q * readout_posterior).sum(dim=1)
        # mask zero-probability entries so log() yields 0 rather than -inf
        p_x[p_x == 0.] = 1.
        true_log_likelihood = p_x.log().sum(dim=0)
        return true_log_likelihood, readout_posterior, emission_target

    def infer(self, p_Q, x_labels, batch):
        """Delegate inference to the emission distribution."""
        return self.emission.infer(p_Q, x_labels)

    def complete_log_likelihood(self, eui, emission_target, batch):
        """Expected complete-data log-likelihood under the posterior eui.

        BUGFIX: zero entries of emission_target are masked to 1 before the
        log (the same guard e_step applies to p_x); otherwise a zero entry
        produces 0 * log(0) = 0 * -inf = NaN and poisons the sum. The mask
        is applied to a clone so the caller's tensor is not mutated.
        """
        safe_target = emission_target.clone()
        safe_target[safe_target == 0.] = 1.
        complete_log_likelihood = (eui * safe_target.log()).sum(1).sum()
        return complete_log_likelihood

    def _m_step(self, x_labels, y_labels, eui, batch):
        """Accumulate sufficient statistics in the emission distribution."""
        # Pass x_labels as y_labels (no targets in the unsupervised setting)
        self.emission._m_step(x_labels, x_labels, eui)

    def m_step(self):
        """Update emission parameters, then reset the accumulators."""
        self.emission.m_step()
        self.init_accumulators()