Skip to content

Commit

Permalink
v1.0.1
Browse files Browse the repository at this point in the history
  • Loading branch information
jmschrei committed May 22, 2023
1 parent 1685e9c commit d9d3748
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 43 deletions.
2 changes: 1 addition & 1 deletion pomegranate/__init__.py
@@ -1 +1 @@
__version__ = "1.0.0"
__version__ = "1.0.1"
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -2,7 +2,7 @@

setup(
name='pomegranate',
version='1.0.0',
version='1.0.1',
author='Jacob Schreiber',
author_email='jmschreiber91@gmail.com',
packages=['pomegranate', 'pomegranate.distributions', 'pomegranate.hmm'],
Expand Down
15 changes: 13 additions & 2 deletions tests/test_bayesian_network.py
Expand Up @@ -16,6 +16,17 @@
from numpy.testing import assert_array_almost_equal


def assert_tuple_equal(x, y):
assert len(x) == len(y)

for i in range(len(x)):
assert len(x[i]) == len(y[i])

if len(x[i]) > 0:
for j in range(len(x[i])):
assert x[i][j] == y[i][j]


@pytest.fixture
def X():
return [[1, 2, 0, 0],
Expand Down Expand Up @@ -991,7 +1002,7 @@ def test_learn_structure_chow_liu(X):
model = BayesianNetwork(algorithm='chow-liu')
model.fit(X)

assert_array_equal(model._parents, ((), (0,), (1,), (1,)))
assert_tuple_equal(model._parents, ((), (0,), (1,), (1,)))

assert_array_almost_equal(model.distributions[0].probs,
[[0.4545, 0.5455]], 4)
Expand Down Expand Up @@ -1019,7 +1030,7 @@ def test_learn_structure_exact(X):
model = BayesianNetwork(algorithm='exact')
model.fit(X)

assert_array_equal(model._parents, ((), (), (), (1, 2)))
assert_tuple_equal(model._parents, ((), (), (), (1, 2)))

assert_array_almost_equal(model.distributions[0].probs,
[[0.4545, 0.5455]], 4)
Expand Down
88 changes: 49 additions & 39 deletions tests/test_bayesian_network_structure_learning.py
Expand Up @@ -17,7 +17,17 @@

from nose.tools import assert_raises
from numpy.testing import assert_array_almost_equal
from numpy.testing import assert_array_equal


def assert_tuple_equal(x, y):
assert len(x) == len(y)

for i in range(len(x)):
assert len(x[i]) == len(y[i])

if len(x[i]) > 0:
for j in range(len(x[i])):
assert x[i][j] == y[i][j]


@pytest.fixture
Expand Down Expand Up @@ -64,40 +74,40 @@ def w():

def test_categorical_chow_liu(X):
structure = _categorical_chow_liu(X)
assert_array_equal(structure, ((), (0,), (0,), (0,)))
assert_tuple_equal(structure, ((), (0,), (0,), (0,)))

structure = _categorical_chow_liu(X, root=1)
assert_array_equal(structure, ((1,), (), (1,), (1,)))
assert_tuple_equal(structure, ((1,), (), (1,), (1,)))

structure = _categorical_chow_liu(X, root=2)
assert_array_equal(structure, ((2,), (2,), (), (2,)))
assert_tuple_equal(structure, ((2,), (2,), (), (2,)))


def test_categorical_chow_liu_weighted(X, w):
structure = _categorical_chow_liu(X, w)
assert_array_equal(structure, ((), (0,), (0,), (0,)))
assert_tuple_equal(structure, ((), (0,), (0,), (0,)))

structure = _categorical_chow_liu(X, w, root=1)
assert_array_equal(structure, ((1,), (), (1,), (1,)))
assert_tuple_equal(structure, ((1,), (), (1,), (1,)))

structure = _categorical_chow_liu(X, w, root=2)
assert_array_equal(structure, ((2,), (2,), (), (2,)))
assert_tuple_equal(structure, ((2,), (2,), (), (2,)))


def test_categorical_chow_liu_large():
numpy.random.seed(0)
X = numpy.random.randint(3, size=(50, 10))

structure = _categorical_chow_liu(X)
assert_array_equal(structure, ((), (0,), (1,), (7,), (0,), (0,), (7,), (2,),
assert_tuple_equal(structure, ((), (0,), (1,), (7,), (0,), (0,), (7,), (2,),
(7,), (0,)))

structure = _categorical_chow_liu(X, root=1)
assert_array_equal(structure, ((1,), (), (1,), (7,), (0,), (0,), (7,), (2,),
assert_tuple_equal(structure, ((1,), (), (1,), (7,), (0,), (0,), (7,), (2,),
(7,), (0,)))

structure = _categorical_chow_liu(X, root=2)
assert_array_equal(structure, ((1,), (2,), (), (7,), (0,), (0,), (7,), (2,),
assert_tuple_equal(structure, ((1,), (2,), (), (7,), (0,), (0,), (7,), (2,),
(7,), (0,)))


Expand All @@ -106,15 +116,15 @@ def test_categorical_chow_liu_large_pseudocount():
X = numpy.random.randint(3, size=(50, 10))

structure = _categorical_chow_liu(X, pseudocount=10)
assert_array_equal(structure, ((), (2,), (4,), (2,), (0,), (3,), (3,), (2,),
assert_tuple_equal(structure, ((), (2,), (4,), (2,), (0,), (3,), (3,), (2,),
(2,), (6,)))

structure = _categorical_chow_liu(X, root=1, pseudocount=10)
assert_array_equal(structure, ((4,), (), (1,), (2,), (2,), (3,), (3,), (2,),
assert_tuple_equal(structure, ((4,), (), (1,), (2,), (2,), (3,), (3,), (2,),
(2,), (6,)))

structure = _categorical_chow_liu(X, root=2, pseudocount=10)
assert_array_equal(structure, ((4,), (2,), (), (2,), (2,), (3,), (3,), (2,),
assert_tuple_equal(structure, ((4,), (2,), (), (2,), (2,), (3,), (3,), (2,),
(2,), (6,)))


Expand Down Expand Up @@ -142,55 +152,55 @@ def test_categorical_chow_liu_raises(X, w):

def test_categorical_exact(X):
structure = _categorical_exact(X)
assert_array_equal(structure, ((), (2,), (), (1, 2)))
assert_tuple_equal(structure, ((), (2,), (), (1, 2)))

structure = _categorical_exact(X, max_parents=1)
assert_array_equal(structure, ((), (2,), (), ()))
assert_tuple_equal(structure, ((), (2,), (), ()))


def test_categorical_exact_weighted(X, w):
structure = _categorical_exact(X, w)
assert_array_equal(structure, ((), (2,), (), (1, 2)))
assert_tuple_equal(structure, ((), (2,), (), (1, 2)))

structure = _categorical_exact(X, w, max_parents=1)
assert_array_equal(structure, ((), (2,), (), ()))
assert_tuple_equal(structure, ((), (2,), (), ()))


def test_categorical_exact_exclude_parents(X):
exclude_parents = ((), (2,), (), (1,))
structure = _categorical_exact(X, exclude_parents=exclude_parents)
assert_array_equal(structure, ((), (), (0, 3), (0,)))
assert_tuple_equal(structure, ((), (), (0, 3), (0,)))

structure = _categorical_exact(X, exclude_parents=exclude_parents,
max_parents=1)
assert_array_equal(structure, ((), (), (), (0,)))
assert_tuple_equal(structure, ((), (), (), (0,)))

exclude_parents = ((), (2,), (), (0, 1))
structure = _categorical_exact(X, exclude_parents=exclude_parents)
assert_array_equal(structure, ((2, 3), (), (), (2,)))
assert_tuple_equal(structure, ((2, 3), (), (), (2,)))


def test_categorical_exact_large():
numpy.random.seed(0)
X = numpy.random.randint(3, size=(50, 10))

structure = _categorical_exact(X)
assert_array_equal(structure, ((), (6, 7), (), (1, 6, 7), (), (1, 3, 6, 7),
assert_tuple_equal(structure, ((), (6, 7), (), (1, 6, 7), (), (1, 3, 6, 7),
(), (6,), (), ()))

structure = _categorical_exact(X, max_parents=1)
assert_array_equal(structure, ((), (), (), (), (), (), (), (3,), (), ()))
assert_tuple_equal(structure, ((), (), (), (), (), (), (), (3,), (), ()))

structure = _categorical_exact(X, max_parents=2)
assert_array_equal(structure, ((), (), (6, 7), (), (), (), (), (6,), (),
assert_tuple_equal(structure, ((), (), (6, 7), (), (), (), (), (6,), (),
()))

structure = _categorical_exact(X, max_parents=3)
assert_array_equal(structure, ((), (), (), (), (), (6, 7), (), (6,),
assert_tuple_equal(structure, ((), (), (), (), (), (6, 7), (), (6,),
(5, 6, 7), ()))

structure = _categorical_exact(X, max_parents=4)
assert_array_equal(structure, ((), (6, 7), (), (1, 6, 7), (), (1, 3, 6, 7),
assert_tuple_equal(structure, ((), (6, 7), (), (1, 6, 7), (), (1, 3, 6, 7),
(), (6,), (), ()))


Expand All @@ -199,22 +209,22 @@ def test_categorical_exact_large_pseudocount():
X = numpy.random.randint(3, size=(50, 8))

structure = _categorical_exact(X, pseudocount=10)
assert_array_equal(structure, ((), (), (), (), (), (), (), ()))
assert_tuple_equal(structure, ((), (), (), (), (), (), (), ()))

structure = _categorical_exact(X, pseudocount=5)
assert_array_equal(structure, ((), (), (), (), (), (), (), ()))
assert_tuple_equal(structure, ((), (), (), (), (), (), (), ()))

structure = _categorical_exact(X, pseudocount=2)
assert_array_equal(structure, ((), (), (), (), (), (), (), ()))
assert_tuple_equal(structure, ((), (), (), (), (), (), (), ()))

structure = _categorical_exact(X, pseudocount=1)
assert_array_equal(structure, ((), (), (), (), (), (), (), ()))
assert_tuple_equal(structure, ((), (), (), (), (), (), (), ()))

structure = _categorical_exact(X, pseudocount=0.1)
assert_array_equal(structure, ((), (), (), (), (), (), (), ()))
assert_tuple_equal(structure, ((), (), (), (), (), (), (), ()))

structure = _categorical_exact(X, pseudocount=1e-8)
assert_array_equal(structure, ((5,), (), (), (), (), (), (), ()))
assert_tuple_equal(structure, ((5,), (), (), (), (), (), (), ()))


###
Expand All @@ -227,19 +237,19 @@ def test_categorical_learn_structure_chow_liu():

structure1 = _categorical_chow_liu(X)
structure2 = _learn_structure(X, algorithm='chow-liu')
assert_array_equal(structure1, structure2)
assert_tuple_equal(structure1, structure2)

structure1 = _categorical_chow_liu(X, root=1)
structure2 = _learn_structure(X, root=1, algorithm='chow-liu')
assert_array_equal(structure1, structure2)
assert_tuple_equal(structure1, structure2)

structure1 = _categorical_chow_liu(X, pseudocount=50)
structure2 = _learn_structure(X, pseudocount=50, algorithm='chow-liu')
assert_array_equal(structure1, structure2)
assert_tuple_equal(structure1, structure2)

structure1 = _categorical_chow_liu(X, w)
structure2 = _learn_structure(X, w, algorithm='chow-liu')
assert_array_equal(structure1, structure2)
assert_tuple_equal(structure1, structure2)


def test_categorical_learn_structure_exact():
Expand All @@ -249,19 +259,19 @@ def test_categorical_learn_structure_exact():

structure1 = _categorical_exact(X)
structure2 = _learn_structure(X, algorithm='exact')
assert_array_equal(structure1, structure2)
assert_tuple_equal(structure1, structure2)

structure1 = _categorical_exact(X, max_parents=1)
structure2 = _learn_structure(X, max_parents=1, algorithm='exact')
assert_array_equal(structure1, structure2)
assert_tuple_equal(structure1, structure2)

structure1 = _categorical_exact(X, pseudocount=50)
structure2 = _learn_structure(X, pseudocount=50, algorithm='exact')
assert_array_equal(structure1, structure2)
assert_tuple_equal(structure1, structure2)

structure1 = _categorical_exact(X, w)
structure2 = _learn_structure(X, w, algorithm='exact')
assert_array_equal(structure1, structure2)
assert_tuple_equal(structure1, structure2)


###
Expand Down

0 comments on commit d9d3748

Please sign in to comment.