Skip to content

Commit

Permalink
bayes nets working!
Browse files Browse the repository at this point in the history
  • Loading branch information
jmschrei committed Jan 3, 2015
1 parent 1227edd commit 70e589f
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 52 deletions.
60 changes: 44 additions & 16 deletions pomegranate/bayesnet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,27 @@ def exp(value):

return numpy.exp(value)

def merge_marginals( marginals ):
    '''
    Merge multiple marginals of the same distribution to form a more informed
    distribution.

    The first parameter of each marginal is read as a mapping from key to
    probability. For every key the probabilities are combined by adding
    their logs across all marginals, and the combined values are then
    renormalized before being wrapped in a new DiscreteDistribution.

    marginals: a non-empty sequence of distribution objects, each exposing
    `parameters[0]` as a dict of key -> probability. Keys are taken from the
    first marginal; later marginals are expected to use the same keys.

    Raises ValueError if `marginals` is empty.
    '''

    # Without at least one marginal there is no key set to merge over; fail
    # with a clear message instead of an IndexError from marginals[0].
    if len( marginals ) == 0:
        raise ValueError( "merge_marginals requires at least one marginal" )

    # Seed the running per-key log values from the first marginal.
    probabilities = { key: _log( value ) for key, value in marginals[0].parameters[0].items() }

    # Fold in every remaining marginal by adding its log values key by key.
    for marginal in marginals[1:]:
        for key, value in marginal.parameters[0].items():
            probabilities[key] += _log( value )

    # Accumulate the log normalizing constant over all keys.
    # NOTE(review): pair_lse is presumably a pairwise log-sum-exp -- confirm.
    total = NEGINF
    for log_probability in probabilities.values():
        total = pair_lse( total, log_probability )

    # Shift by the normalizer in log space, then convert back with cexp.
    normalized = { key: cexp( value - total ) for key, value in probabilities.items() }

    return DiscreteDistribution( normalized )

cdef class BayesianNetwork( Model ):
"""
Represents a Bayesian Network
Expand Down Expand Up @@ -197,58 +218,65 @@ cdef class BayesianNetwork( Model ):
names = { state.name: state.distribution for state in self.states }
data = { names[state]: value for state, value in data.items() }

factors = [ data[ s.distribution ] if s.distribution in data else None for s in self.states ]
factors = [ data[ s.distribution ] if s.distribution in data else s.distribution.marginal() for s in self.states ]
new_factors = [ i for i in factors ]

in_edges = numpy.array( self.in_edge_count )
out_edges = numpy.array( self.out_edge_count )

messages = [ None for i in in_edges ]

leaves = numpy.where( out_edges[1:] - out_edges[:-1] == 0 )[0]
visited = numpy.zeros( len( self.states ) )
visited[leaves] = 1
for i, factor in enumerate( factors ):
if factor is not None and not isinstance( factor, Distribution ):
for i, s in enumerate( self.states ):
if s.distribution in data and not isinstance( data[ s.distribution ], Distribution ):
visited[i] = 1

while True:
for i, state in enumerate( self.states ):
if visited[ i ] == 1:
if visited[i] == 1:
continue

state = self.states[ i ]
state = self.states[i]
d = state.distribution

for k in xrange( out_edges[i], out_edges[i+1] ):
ki = self.out_transitions[k]
if visited[ki] == 0:
break
else:
parents = {}
for k in xrange( in_edges[i], in_edges[i+1] ):
ki = self.in_transitions[k]
parents[ self.states[ki].distribution ] = factors[ki]

for k in xrange( out_edges[i], out_edges[i+1] ):
ki = self.out_transitions[k]

if not isinstance( factors[i], Distribution ):
factors[i] = self.states[ki].distribution.marginal( parents, wrt=d, value=factors[ki] )
else:
factors[i] = self.states[ki].distribution.marginal( parents, wrt=factors[i], value=factors[ki] )
parents = {}
for l in xrange( in_edges[ki], in_edges[ki+1] ):
li = self.in_transitions[l]
parents[ self.states[li].distribution ] = factors[li]

messages[k] = self.states[ki].distribution.marginal( parents, wrt=d, value=new_factors[ki] )
else:
local_messages = [ factors[i] ] + [ messages[k] for k in xrange( out_edges[i], out_edges[i+1] ) ]
new_factors[i] = merge_marginals( local_messages )

visited[i] = 1

if visited.sum() == visited.shape[0]:
break

return factors
return new_factors


def forward_backward( self, data={} ):
'''
Run the forward pass followed by the backward pass and return the result
of the backward pass.

The factors returned by self.forward( data ) are re-keyed by the name of
the state at the same index, and that dictionary is handed to
self.backward(). The `data` default is rebound below, never mutated in
place.
'''

# NOTE(review): this block is shown as a rendered diff, so the print
# statements below may mix removed and added lines -- confirm against the
# actual file before relying on the exact debug output.
print "FORWARD LOGS"
factors = self.forward( data )
print factors
# Key each factor by the name of the state at the same position.
data = { self.states[i].name: factors[i] for i in xrange( len(factors) ) }

print "\n".join( map( str, factors ) )
print
print "BACKWARD LOGS"
return self.backward( data )
8 changes: 6 additions & 2 deletions pomegranate/distributions.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1826,7 +1826,7 @@ cdef class ConditionalDiscreteDistribution( ConditionalDistribution ):
'''

d, pd, keys = self.parameters

if not wrt:
return DiscreteDistribution({ key: math.e ** recursive_discrete_log_probability( key, d, parent_values, pd )
for key in keys })
Expand All @@ -1851,9 +1851,13 @@ cdef class ConditionalDiscreteDistribution( ConditionalDistribution ):
vkey, d, pv, pd ) + value.log_probability( vkey ) )
probabilities[ key ] = sum( vprobabilities.values() )

total = sum( probabilities.values() )
for key, value in probabilities.items():
probabilities[ key ] = value / total

return DiscreteDistribution( probabilities )

def log_probability( self, symbol, parent_values ):
def log_probability( self, symbol, parent_values={} ):
'''
Calculate the log probability of the symbol given the parents. If the
parent is not observed, marginalize over that distribution. This is
Expand Down
104 changes: 70 additions & 34 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,10 @@
import math
from pomegranate import *
'''
a = DiscreteDistribution( { 'T': 0.90, 'H': 0.10 } )
b = DiscreteDistribution( { 'T': 0.40, 'H': 0.60 } )
c = ConditionalDiscreteDistribution(
{ 'T' : { 'T' : DiscreteDistribution( { 'T' : 0.99, 'H' : 0.01 } ),
'H' : DiscreteDistribution( { 'T' : 0.20, 'H' : 0.80 } ) },
'H' : { 'T' : DiscreteDistribution( { 'T' : 0.67, 'H' : 0.33 } ),
'H' : DiscreteDistribution( { 'T' : 0.37, 'H' : 0.63 } ) }
}, [ a, b ], [ 'T', 'H' ] )
s1 = State( a, name="s1" )
s2 = State( b, name="s2" )
s3 = State( c, name="s3" )

network = BayesianNetwork( "test" )
network.add_states( [s1, s2, s3] )
network.add_transition( s1, s3, 1.0 )
network.add_transition( s2, s3, 1.0 )
network.bake()
print "RANDOM EXAMPLE"
print
print math.e ** c.log_probability( 'T', { a : 'T', b : 'H', c : 'T' } )
print
print math.e ** network.log_probability( { "s1" : 'T', "s2" : 'H', "s3" : 'T' } )
print
print network.belief_propogation( { "s1": 'T', "s2": 'H' } )
#######################
# MODIFIED MONTY HALL #
#######################
'''


friend = DiscreteDistribution( { 'A': 1./3, 'B': 1./3, 'C': 1./3 } )
guest = ConditionalDiscreteDistribution( {
Expand Down Expand Up @@ -67,9 +41,71 @@
network.bake()
print "\t".join([ state.name for state in network.states ])
print "\n".join( map( str, network.forward_backward( { 'friend' : 'A', 'monty' : 'B' } ) ) )
'''
################
# ASIA EXAMPLE #
################

asia = DiscreteDistribution({ 'True' : 0.5, 'False' : 0.5 })
tuberculosis = ConditionalDiscreteDistribution({
'True' : DiscreteDistribution({ 'True' : 0.2, 'False' : 0.80 }),
'False' : DiscreteDistribution({ 'True' : 0.01, 'False' : 0.99 })
}, [asia])

smoking = DiscreteDistribution({ 'True' : 0.5, 'False' : 0.5 })
lung = ConditionalDiscreteDistribution({
'True' : DiscreteDistribution({ 'True' : 0.75, 'False' : 0.25 }),
'False' : DiscreteDistribution({ 'True' : 0.02, 'False' : 0.98 })
}, [smoking] )
bronchitis = ConditionalDiscreteDistribution({
'True' : DiscreteDistribution({ 'True' : 0.92, 'False' : 0.08 }),
'False' : DiscreteDistribution({ 'True' : 0.03, 'False' : 0.97})
}, [smoking] )

tuberculosis_or_cancer = ConditionalDiscreteDistribution({
'True' : { 'True' : DiscreteDistribution({ 'True' : 1.0, 'False' : 0.0 }),
'False' : DiscreteDistribution({ 'True' : 1.0, 'False' : 0.0 }),
},
'False' : { 'True' : DiscreteDistribution({ 'True' : 1.0, 'False' : 0.0 }),
'False' : DiscreteDistribution({ 'True' : 0.0, 'False' : 1.0 })
}
}, [tuberculosis, lung] )

xray = ConditionalDiscreteDistribution({
'True' : DiscreteDistribution({ 'True' : .885, 'False' : .115 }),
'False' : DiscreteDistribution({ 'True' : 0.04, 'False' : 0.96 })
}, [tuberculosis_or_cancer] )

dyspnea = ConditionalDiscreteDistribution({
'True' : { 'True' : DiscreteDistribution({ 'True' : 0.96, 'False' : 0.04 }),
'False' : DiscreteDistribution({ 'True' : 0.89, 'False' : 0.11 })
},
'False' : { 'True' : DiscreteDistribution({ 'True' : 0.82, 'False' : 0.18 }),
'False' : DiscreteDistribution({ 'True' : 0.4, 'False' : 0.6 })
}
}, [tuberculosis_or_cancer, bronchitis])

s0 = State( asia, name="asia" )
s1 = State( tuberculosis, name="tuberculosis" )
s2 = State( smoking, name="smoker" )
s3 = State( lung, name="cancer" )
s4 = State( bronchitis, name="bronchitis" )
s5 = State( tuberculosis_or_cancer, name="TvC" )
s6 = State( xray, name="xray" )
s7 = State( dyspnea, name='dyspnea' )

network = BayesianNetwork( "asia" )
network.add_states([ s0, s1, s2, s3, s4, s5, s6, s7 ])
network.add_transition( s0, s1, 1.0 )
network.add_transition( s1, s5, 1.0 )
network.add_transition( s2, s3, 1.0 )
network.add_transition( s2, s4, 1.0 )
network.add_transition( s3, s5, 1.0 )
network.add_transition( s5, s6, 1.0 )
network.add_transition( s5, s7, 1.0 )
network.add_transition( s4, s7, 1.0 )
network.bake()


#print monty.marginal( wrt=guest, value='A' )

print "\n".join( map( str, network.forward_backward( { 'monty' : 'A', 'friend' : 'B' } ) ) )

print "\t".join([ state.name for state in network.states ])
print "\n".join( map( str, network.forward_backward({ 'tuberculosis' : 'True', 'smoker' : 'False', 'bronchitis' : DiscreteDistribution({ 'True' : 0.8, 'False' : 0.2 }) }) ) )

0 comments on commit 70e589f

Please sign in to comment.