From f17b1d2b2897be99efc7aae692fbecd8be8ce778 Mon Sep 17 00:00:00 2001 From: Savyasachee Jha Date: Tue, 15 Aug 2023 23:37:39 +0530 Subject: [PATCH] Fixing the _categorical_exact function and the tests to be more deterministic (#1059) Thanks so much for finding this issue and submitting a great PR fixing it. --- pomegranate/bayesian_network.py | 4 +-- tests/test_bayesian_network.py | 26 ++++++++-------- ...est_bayesian_network_structure_learning.py | 30 +++++++++---------- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/pomegranate/bayesian_network.py b/pomegranate/bayesian_network.py index c3e9c13c..05274957 100644 --- a/pomegranate/bayesian_network.py +++ b/pomegranate/bayesian_network.py @@ -954,8 +954,8 @@ def _categorical_exact(X, sample_weight=None, include_parents=None, order_graph.add_edge(parent, subset, weight=weight, structure=structure) - path = nx.shortest_path(order_graph, source=(), target=tuple(range(d)), - weight='weight') + path = sorted(nx.all_shortest_paths(order_graph, source=(), + target=tuple(range(d)), weight="weight"))[1] score, structure = 0, list( None for i in range(d) ) for u, v in zip(path[:-1], path[1:]): diff --git a/tests/test_bayesian_network.py b/tests/test_bayesian_network.py index ea827fbe..c33a55f6 100644 --- a/tests/test_bayesian_network.py +++ b/tests/test_bayesian_network.py @@ -1030,33 +1030,33 @@ def test_learn_structure_exact(X): model = BayesianNetwork(algorithm='exact') model.fit(X) - assert_tuple_equal(model._parents, ((), (), (), (1, 2))) + assert_tuple_equal(model._parents, ((), (0, 2), (), ())) assert_array_almost_equal(model.distributions[0].probs, [[0.4545, 0.5455]], 4) - assert_array_almost_equal(model.distributions[1].probs, + assert_array_almost_equal(model.distributions[3].probs, [[0.4545, 0.5455]], 4) assert_array_almost_equal(model.distributions[2].probs, [[0.5455, 0.4545]], 4) - assert_array_almost_equal(model.distributions[3].probs[0], - [[[1., 0.], - [1., 0.]], + assert_array_almost_equal(model.distributions[1].probs[0], + [[[0.3333, 0.6667], + [1.0000, 0.0000]], - [[0., 1.], - [0., 1.]]], 4) + [[0.3333, 0.6667], + [0.3333, 0.6667]]], 4) assert_array_almost_equal(model._factor_graph.factors[0].probs, [[0.4545, 0.5455]], 4) - assert_array_almost_equal(model._factor_graph.factors[1].probs, + assert_array_almost_equal(model._factor_graph.factors[3].probs, [[0.4545, 0.5455]], 4) assert_array_almost_equal(model._factor_graph.factors[2].probs, [[0.5455, 0.4545]], 4) - assert_array_almost_equal(model._factor_graph.factors[3].probs, - [[[0.25, 0.], - [0.25, 0.]], + assert_array_almost_equal(model._factor_graph.factors[1].probs, + [[[0.0833, 0.1667], + [0.2500, 0.0000]], - [[0., 0.25], - [0., 0.25]]], 4) + [[0.0833, 0.1667], + [0.0833, 0.1667]]], 4) def test_summarize(X, distributions): diff --git a/tests/test_bayesian_network_structure_learning.py b/tests/test_bayesian_network_structure_learning.py index 7f2dd3f2..e62d8c29 100644 --- a/tests/test_bayesian_network_structure_learning.py +++ b/tests/test_bayesian_network_structure_learning.py @@ -152,18 +152,18 @@ def test_categorical_chow_liu_raises(X, w): def test_categorical_exact(X): structure = _categorical_exact(X) - assert_tuple_equal(structure, ((), (2,), (), (1, 2))) + assert_tuple_equal(structure, ((), (0,), (), (0, 1))) structure = _categorical_exact(X, max_parents=1) - assert_tuple_equal(structure, ((), (2,), (), ())) + assert_tuple_equal(structure, ((), (0,), (), ())) def test_categorical_exact_weighted(X, w): structure = _categorical_exact(X, w) - assert_tuple_equal(structure, ((), (2,), (), (1, 2))) + assert_tuple_equal(structure, ((), (0,), (), (0, 1))) structure = _categorical_exact(X, w, max_parents=1) - assert_tuple_equal(structure, ((), (2,), (), ())) + assert_tuple_equal(structure, ((), (0,), (), ())) def test_categorical_exact_exclude_parents(X): @@ -173,11 +173,11 @@ def test_categorical_exact_exclude_parents(X): structure = _categorical_exact(X, exclude_parents=exclude_parents, max_parents=1) - assert_tuple_equal(structure, ((), (), (), (0,))) + assert_tuple_equal(structure, ((), (), (0,), ())) exclude_parents = ((), (2,), (), (0, 1)) structure = _categorical_exact(X, exclude_parents=exclude_parents) - assert_tuple_equal(structure, ((2, 3), (), (), (2,))) + assert_tuple_equal(structure, ((3,), (), (0,3), ())) def test_categorical_exact_large(): @@ -185,23 +185,23 @@ def test_categorical_exact_large(): X = numpy.random.randint(3, size=(50, 10)) structure = _categorical_exact(X) - assert_tuple_equal(structure, ((), (6, 7), (), (1, 6, 7), (), (1, 3, 6, 7), - (), (6,), (), ())) + assert_tuple_equal(structure, ((1, 2), (), (), (), (0, 1, 2), (), (), + (), (0, 1, 2, 4), ())) structure = _categorical_exact(X, max_parents=1) - assert_tuple_equal(structure, ((), (), (), (), (), (), (), (3,), (), ())) + assert_tuple_equal(structure, ((), (), (), (), (), (), (), (2,), (), ())) structure = _categorical_exact(X, max_parents=2) - assert_tuple_equal(structure, ((), (), (6, 7), (), (), (), (), (6,), (), + assert_tuple_equal(structure, ((), (), (), (), (), (), (), (0, 1), (), ())) structure = _categorical_exact(X, max_parents=3) - assert_tuple_equal(structure, ((), (), (), (), (), (6, 7), (), (6,), - (5, 6, 7), ())) + assert_tuple_equal(structure, ((1, 2), (), (), (), (0, 1, 2), (), (), (), + (), ())) structure = _categorical_exact(X, max_parents=4) - assert_tuple_equal(structure, ((), (6, 7), (), (1, 6, 7), (), (1, 3, 6, 7), - (), (6,), (), ())) + assert_tuple_equal(structure, ((1, 2), (), (), (), (0, 1, 2), (), (), (), + (0, 1, 2, 4), ())) def test_categorical_exact_large_pseudocount(): @@ -224,7 +224,7 @@ def test_categorical_exact_large_pseudocount(): assert_tuple_equal(structure, ((), (), (), (), (), (), (), ())) structure = _categorical_exact(X, pseudocount=1e-8) - assert_tuple_equal(structure, ((5,), (), (), (), (), (), (), ())) + assert_tuple_equal(structure, ((), (), (), (), (), (0,), (), ())) ###