From 02b8b3036f18816190ed3e89a29e053bc8962016 Mon Sep 17 00:00:00 2001 From: Savyasachee Jha Date: Sat, 26 Aug 2023 02:46:50 +0530 Subject: [PATCH] Fix typo in the _categorical_exact function (#1060) This typo would cause the function to error if there was only a single shortest path. --- pomegranate/bayesian_network.py | 2 +- tests/test_bayesian_network.py | 22 +++++++++---------- ...est_bayesian_network_structure_learning.py | 10 ++++----- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/pomegranate/bayesian_network.py b/pomegranate/bayesian_network.py index 05274957..dbc3d8c4 100644 --- a/pomegranate/bayesian_network.py +++ b/pomegranate/bayesian_network.py @@ -955,7 +955,7 @@ def _categorical_exact(X, sample_weight=None, include_parents=None, structure=structure) path = sorted(nx.all_shortest_paths(order_graph, source=(), - target=tuple(range(d)), weight="weight"))[1] + target=tuple(range(d)), weight="weight"))[0] score, structure = 0, list( None for i in range(d) ) for u, v in zip(path[:-1], path[1:]): diff --git a/tests/test_bayesian_network.py b/tests/test_bayesian_network.py index c33a55f6..d8945c2c 100644 --- a/tests/test_bayesian_network.py +++ b/tests/test_bayesian_network.py @@ -1030,33 +1030,33 @@ def test_learn_structure_exact(X): model = BayesianNetwork(algorithm='exact') model.fit(X) - assert_tuple_equal(model._parents, ((), (0, 2), (), ())) + assert_tuple_equal(model._parents, ((), (), (0, 1), ())) assert_array_almost_equal(model.distributions[0].probs, [[0.4545, 0.5455]], 4) - assert_array_almost_equal(model.distributions[3].probs, + assert_array_almost_equal(model.distributions[1].probs, [[0.4545, 0.5455]], 4) - assert_array_almost_equal(model.distributions[2].probs, - [[0.5455, 0.4545]], 4) - assert_array_almost_equal(model.distributions[1].probs[0], + assert_array_almost_equal(model.distributions[3].probs, + [[0.4545, 0.5455]], 4) + assert_array_almost_equal(model.distributions[2].probs[0], [[[0.3333, 0.6667], [1.0000, 0.0000]], - [[0.3333, 0.6667], - [0.3333, 0.6667]]], 4) + [[0.5, 0.5], + [0.5, 0.5]]], 4) assert_array_almost_equal(model._factor_graph.factors[0].probs, [[0.4545, 0.5455]], 4) + assert_array_almost_equal(model._factor_graph.factors[1].probs, + [[0.4545, 0.5455]], 4) assert_array_almost_equal(model._factor_graph.factors[3].probs, [[0.4545, 0.5455]], 4) assert_array_almost_equal(model._factor_graph.factors[2].probs, - [[0.5455, 0.4545]], 4) - assert_array_almost_equal(model._factor_graph.factors[1].probs, [[[0.0833, 0.1667], [0.2500, 0.0000]], - [[0.0833, 0.1667], - [0.0833, 0.1667]]], 4) + [[0.1250, 0.1250], + [0.1250, 0.1250]]], 4) def test_summarize(X, distributions): diff --git a/tests/test_bayesian_network_structure_learning.py b/tests/test_bayesian_network_structure_learning.py index e62d8c29..174e0d8d 100644 --- a/tests/test_bayesian_network_structure_learning.py +++ b/tests/test_bayesian_network_structure_learning.py @@ -152,7 +152,7 @@ def test_categorical_chow_liu_raises(X, w): def test_categorical_exact(X): structure = _categorical_exact(X) - assert_tuple_equal(structure, ((), (0,), (), (0, 1))) + assert_tuple_equal(structure, ((), (0,), (0, 1), ())) structure = _categorical_exact(X, max_parents=1) assert_tuple_equal(structure, ((), (0,), (), ())) @@ -160,7 +160,7 @@ def test_categorical_exact(X): def test_categorical_exact_weighted(X, w): structure = _categorical_exact(X, w) - assert_tuple_equal(structure, ((), (0,), (), (0, 1))) + assert_tuple_equal(structure, ((), (0,), (0, 1), ())) structure = _categorical_exact(X, w, max_parents=1) assert_tuple_equal(structure, ((), (0,), (), ())) @@ -169,15 +169,15 @@ def test_categorical_exact_weighted(X, w): def test_categorical_exact_exclude_parents(X): exclude_parents = ((), (2,), (), (1,)) structure = _categorical_exact(X, exclude_parents=exclude_parents) - assert_tuple_equal(structure, ((), (), (0, 3), (0,))) + assert_tuple_equal(structure, ((), (), (0,), (0, 2))) structure = _categorical_exact(X, exclude_parents=exclude_parents, max_parents=1) - assert_tuple_equal(structure, ((), (), (0,), ())) + assert_tuple_equal(structure, ((), (0,), (), ())) exclude_parents = ((), (2,), (), (0, 1)) structure = _categorical_exact(X, exclude_parents=exclude_parents) - assert_tuple_equal(structure, ((3,), (), (0,3), ())) + assert_tuple_equal(structure, ((2, 3), (), (), (2,))) def test_categorical_exact_large():