Skip to content

Commit

Permalink
Fixing the _categorical_exact function and the tests to be more deter…
Browse files Browse the repository at this point in the history
…ministic (#1059)

Thanks so much for finding this issue and submitting a great PR fixing it.
  • Loading branch information
savyajha committed Aug 15, 2023
1 parent b8a0c38 commit f17b1d2
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 30 deletions.
4 changes: 2 additions & 2 deletions pomegranate/bayesian_network.py
Expand Up @@ -954,8 +954,8 @@ def _categorical_exact(X, sample_weight=None, include_parents=None,
order_graph.add_edge(parent, subset, weight=weight,
structure=structure)

path = nx.shortest_path(order_graph, source=(), target=tuple(range(d)),
weight='weight')
path = sorted(nx.all_shortest_paths(order_graph, source=(),
target=tuple(range(d)), weight="weight"))[1]

score, structure = 0, list( None for i in range(d) )
for u, v in zip(path[:-1], path[1:]):
Expand Down
26 changes: 13 additions & 13 deletions tests/test_bayesian_network.py
Expand Up @@ -1030,33 +1030,33 @@ def test_learn_structure_exact(X):
model = BayesianNetwork(algorithm='exact')
model.fit(X)

assert_tuple_equal(model._parents, ((), (), (), (1, 2)))
assert_tuple_equal(model._parents, ((), (0, 2), (), ()))

assert_array_almost_equal(model.distributions[0].probs,
[[0.4545, 0.5455]], 4)
assert_array_almost_equal(model.distributions[1].probs,
assert_array_almost_equal(model.distributions[3].probs,
[[0.4545, 0.5455]], 4)
assert_array_almost_equal(model.distributions[2].probs,
[[0.5455, 0.4545]], 4)
assert_array_almost_equal(model.distributions[3].probs[0],
[[[1., 0.],
[1., 0.]],
assert_array_almost_equal(model.distributions[1].probs[0],
[[[0.3333, 0.6667],
[1.0000, 0.0000]],

[[0., 1.],
[0., 1.]]], 4)
[[0.3333, 0.6667],
[0.3333, 0.6667]]], 4)

assert_array_almost_equal(model._factor_graph.factors[0].probs,
[[0.4545, 0.5455]], 4)
assert_array_almost_equal(model._factor_graph.factors[1].probs,
assert_array_almost_equal(model._factor_graph.factors[3].probs,
[[0.4545, 0.5455]], 4)
assert_array_almost_equal(model._factor_graph.factors[2].probs,
[[0.5455, 0.4545]], 4)
assert_array_almost_equal(model._factor_graph.factors[3].probs,
[[[0.25, 0.],
[0.25, 0.]],
assert_array_almost_equal(model._factor_graph.factors[1].probs,
[[[0.0833, 0.1667],
[0.2500, 0.0000]],

[[0., 0.25],
[0., 0.25]]], 4)
[[0.0833, 0.1667],
[0.0833, 0.1667]]], 4)


def test_summarize(X, distributions):
Expand Down
30 changes: 15 additions & 15 deletions tests/test_bayesian_network_structure_learning.py
Expand Up @@ -152,18 +152,18 @@ def test_categorical_chow_liu_raises(X, w):

def test_categorical_exact(X):
structure = _categorical_exact(X)
assert_tuple_equal(structure, ((), (2,), (), (1, 2)))
assert_tuple_equal(structure, ((), (0,), (), (0, 1)))

structure = _categorical_exact(X, max_parents=1)
assert_tuple_equal(structure, ((), (2,), (), ()))
assert_tuple_equal(structure, ((), (0,), (), ()))


def test_categorical_exact_weighted(X, w):
structure = _categorical_exact(X, w)
assert_tuple_equal(structure, ((), (2,), (), (1, 2)))
assert_tuple_equal(structure, ((), (0,), (), (0, 1)))

structure = _categorical_exact(X, w, max_parents=1)
assert_tuple_equal(structure, ((), (2,), (), ()))
assert_tuple_equal(structure, ((), (0,), (), ()))


def test_categorical_exact_exclude_parents(X):
Expand All @@ -173,35 +173,35 @@ def test_categorical_exact_exclude_parents(X):

structure = _categorical_exact(X, exclude_parents=exclude_parents,
max_parents=1)
assert_tuple_equal(structure, ((), (), (), (0,)))
assert_tuple_equal(structure, ((), (), (0,), ()))

exclude_parents = ((), (2,), (), (0, 1))
structure = _categorical_exact(X, exclude_parents=exclude_parents)
assert_tuple_equal(structure, ((2, 3), (), (), (2,)))
assert_tuple_equal(structure, ((3,), (), (0,3), ()))


def test_categorical_exact_large():
numpy.random.seed(0)
X = numpy.random.randint(3, size=(50, 10))

structure = _categorical_exact(X)
assert_tuple_equal(structure, ((), (6, 7), (), (1, 6, 7), (), (1, 3, 6, 7),
(), (6,), (), ()))
assert_tuple_equal(structure, ((1, 2), (), (), (), (0, 1, 2), (), (),
(), (0, 1, 2, 4), ()))

structure = _categorical_exact(X, max_parents=1)
assert_tuple_equal(structure, ((), (), (), (), (), (), (), (3,), (), ()))
assert_tuple_equal(structure, ((), (), (), (), (), (), (), (2,), (), ()))

structure = _categorical_exact(X, max_parents=2)
assert_tuple_equal(structure, ((), (), (6, 7), (), (), (), (), (6,), (),
assert_tuple_equal(structure, ((), (), (), (), (), (), (), (0, 1), (),
()))

structure = _categorical_exact(X, max_parents=3)
assert_tuple_equal(structure, ((), (), (), (), (), (6, 7), (), (6,),
(5, 6, 7), ()))
assert_tuple_equal(structure, ((1, 2), (), (), (), (0, 1, 2), (), (), (),
(), ()))

structure = _categorical_exact(X, max_parents=4)
assert_tuple_equal(structure, ((), (6, 7), (), (1, 6, 7), (), (1, 3, 6, 7),
(), (6,), (), ()))
assert_tuple_equal(structure, ((1, 2), (), (), (), (0, 1, 2), (), (), (),
(0, 1, 2, 4), ()))


def test_categorical_exact_large_pseudocount():
Expand All @@ -224,7 +224,7 @@ def test_categorical_exact_large_pseudocount():
assert_tuple_equal(structure, ((), (), (), (), (), (), (), ()))

structure = _categorical_exact(X, pseudocount=1e-8)
assert_tuple_equal(structure, ((5,), (), (), (), (), (), (), ()))
assert_tuple_equal(structure, ((), (), (), (), (), (0,), (), ()))


###
Expand Down

0 comments on commit f17b1d2

Please sign in to comment.