Skip to content

Commit

Permalink
FIX unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jmschrei committed Apr 16, 2023
1 parent 1f874dd commit 1731671
Show file tree
Hide file tree
Showing 39 changed files with 129 additions and 155 deletions.
8 changes: 1 addition & 7 deletions pomegranate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1 @@
from .bayes_classifier import BayesClassifier
from .gmm import GeneralMixtureModel
from .kmeans import KMeans

from .distributions import *

__version__ = "0.5.0"
__version__ = "1.0.0"
8 changes: 4 additions & 4 deletions pomegranate/bayesian_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def add_distribution(self, distribution):
Parameters
----------
distribution: torchegranate.distributions.Distribution
distribution: pomegranate.distributions.Distribution
A distribution object to include as a node. Currently must be a
Categorical or a ConditionalCategorical distribution.
"""
Expand Down Expand Up @@ -225,10 +225,10 @@ def add_edge(self, parent, child):
Parameters
----------
parent: torchegranate.distributions.Distribution
parent: pomegranate.distributions.Distribution
The distribution that the edge begins at.
child: torchegranate.distributions.Distribution
child: pomegranate.distributions.Distribution
The distribution that the edge points to.
"""

Expand Down Expand Up @@ -630,7 +630,7 @@ def _from_structure(X, sample_weight=None, structure=None, pseudocount=0.0):
Returns
-------
model: torchegranate.bayesian_network.BayesianNetwork
model: pomegranate.bayesian_network.BayesianNetwork
The fit Bayesian network.
"""

Expand Down
1 change: 0 additions & 1 deletion pomegranate/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from .gamma import Gamma
from .independent_components import IndependentComponents
from .joint_categorical import JointCategorical
from .lognormal import LogNormal
from .normal import Normal
from .poisson import Poisson
from .student_t import StudentT
Expand Down
6 changes: 3 additions & 3 deletions pomegranate/distributions/zero_inflated.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# zeroinflated.py
# zero_inflated.py
# Contact: Jacob Schreiber <jmschreiber91@gmail.com>

import time
Expand Down Expand Up @@ -32,8 +32,8 @@ class ZeroInflated(Distribution):
Parameters
----------
distribution: torchegranate.distributions._distribution.Distribution
A torchegranate distribution object. It should probably be a discrete
distribution: pomegranate.distributions.Distribution
A pomegranate distribution object. It should probably be a discrete
distribution, but does not technically have to be.
priors: tuple, numpy.ndarray, torch.Tensor, or None. shape=(2,), optional
Expand Down
8 changes: 4 additions & 4 deletions pomegranate/factor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def add_factor(self, distribution):
Parameters
----------
distribution: torchegranate.distributions.Distribution
distribution: pomegranate.distributions.Distribution
A distribution object to include as a node.
"""

Expand All @@ -180,7 +180,7 @@ def add_marginal(self, distribution):
Parameters
----------
distribution: torchegranate.distributions.Distribution
distribution: pomegranate.distributions.Distribution
A distribution object to include as a node.
"""

Expand All @@ -202,10 +202,10 @@ def add_edge(self, marginal, factor):
Parameters
----------
marginal: torchegranate.distributions.Distribution
marginal: pomegranate.distributions.Distribution
The marginal distribution to include in the edge.
factor: torchegranate.distributions.Distribution
factor: pomegranate.distributions.Distribution
The factor distribution to include in the edge.
"""

Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_bernoulli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import pytest

from torchegranate.distributions import Bernoulli
from pomegranate.distributions import Bernoulli

from ._utils import _test_initialization_raises_one_parameter
from ._utils import _test_initialization
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import pytest

from torchegranate.distributions import Categorical
from pomegranate.distributions import Categorical

from ._utils import _test_initialization_raises_one_parameter
from ._utils import _test_predictions
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_conditional_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import pytest

from torchegranate.distributions import ConditionalCategorical
from pomegranate.distributions import ConditionalCategorical

from ._utils import _test_initialization_raises_one_parameter
from ._utils import _test_initialization
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_dirac_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import pytest

from torchegranate.distributions import DiracDelta
from pomegranate.distributions import DiracDelta

from ._utils import _test_initialization_raises_one_parameter
from ._utils import _test_initialization
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import pytest

from torchegranate.distributions import Exponential
from pomegranate.distributions import Exponential

from ._utils import _test_initialization_raises_one_parameter
from ._utils import _test_initialization
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import pytest

from torchegranate.distributions import Gamma
from pomegranate.distributions import Gamma

from ._utils import _test_initialization_raises_two_parameters
from ._utils import _test_initialization
Expand Down
8 changes: 4 additions & 4 deletions tests/distributions/test_independent_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import torch
import pytest

from torchegranate.distributions import Exponential
from torchegranate.distributions import Gamma
from torchegranate.distributions import Categorical
from torchegranate.distributions import IndependentComponents
from pomegranate.distributions import Exponential
from pomegranate.distributions import Gamma
from pomegranate.distributions import Categorical
from pomegranate.distributions import IndependentComponents

from ._utils import _test_initialization_raises_one_parameter
from ._utils import _test_initialization
Expand Down
4 changes: 2 additions & 2 deletions tests/distributions/test_joint_categorical.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# test_JointCategorical.py
# test_joint_categorical.py
# Contact: Jacob Schreiber <jmschreiber91@gmail.com>

import numpy
import torch
import pytest

from torchegranate.distributions import JointCategorical
from pomegranate.distributions import JointCategorical

from ._utils import _test_initialization_raises_one_parameter
from ._utils import _test_initialization
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_normal_diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import pytest

from torchegranate.distributions import Normal
from pomegranate.distributions import Normal

from ._utils import _test_initialization_raises_two_parameters
from ._utils import _test_initialization
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_normal_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import pytest

from torchegranate.distributions import Normal
from pomegranate.distributions import Normal

from ._utils import _test_initialization_raises_two_parameters
from ._utils import _test_initialization
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import pytest

from torchegranate.distributions import Poisson
from pomegranate.distributions import Poisson

from ._utils import _test_initialization_raises_one_parameter
from ._utils import _test_initialization
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_student_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import pytest

from torchegranate.distributions import StudentT
from pomegranate.distributions import StudentT

from ._utils import _test_initialization_raises_two_parameters
from ._utils import _test_initialization
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import pytest

from torchegranate.distributions import Uniform
from pomegranate.distributions import Uniform

from ._utils import _test_initialization_raises_two_parameters
from ._utils import _test_initialization
Expand Down
6 changes: 3 additions & 3 deletions tests/hmm/test_dense_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import torch
import pytest

from torchegranate.hmm import DenseHMM
from torchegranate.distributions import Exponential
from torchegranate.distributions import Gamma
from pomegranate.hmm import DenseHMM
from pomegranate.distributions import Exponential
from pomegranate.distributions import Gamma

from ..distributions._utils import _test_initialization_raises_one_parameter
from ..distributions._utils import _test_initialization
Expand Down
6 changes: 3 additions & 3 deletions tests/hmm/test_sparse_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import torch
import pytest

from torchegranate.hmm import SparseHMM
from torchegranate.distributions import Exponential
from torchegranate.distributions import Gamma
from pomegranate.hmm import SparseHMM
from pomegranate.distributions import Exponential
from pomegranate.distributions import Gamma

from ..distributions._utils import _test_initialization_raises_one_parameter
from ..distributions._utils import _test_initialization
Expand Down
4 changes: 2 additions & 2 deletions tests/test_bayes_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import torch
import pytest

from torchegranate.bayes_classifier import BayesClassifier
from torchegranate.distributions import Exponential
from pomegranate.bayes_classifier import BayesClassifier
from pomegranate.distributions import Exponential

from .distributions._utils import _test_initialization_raises_one_parameter
from .distributions._utils import _test_initialization
Expand Down
8 changes: 4 additions & 4 deletions tests/test_bayesian_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import torch
import pytest

from torchegranate.bayesian_network import BayesianNetwork
from torchegranate.distributions import Exponential
from torchegranate.distributions import Categorical
from torchegranate.distributions import ConditionalCategorical
from pomegranate.bayesian_network import BayesianNetwork
from pomegranate.distributions import Exponential
from pomegranate.distributions import Categorical
from pomegranate.distributions import ConditionalCategorical


from nose.tools import assert_raises
Expand Down
14 changes: 7 additions & 7 deletions tests/test_bayesian_network_structure_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
import torch
import pytest

from torchegranate.bayesian_network import BayesianNetwork
from pomegranate.bayesian_network import BayesianNetwork

from torchegranate.bayesian_network import _from_structure
from torchegranate.bayesian_network import _learn_structure
from torchegranate.bayesian_network import _categorical_exact
from torchegranate.bayesian_network import _categorical_chow_liu
from pomegranate.bayesian_network import _from_structure
from pomegranate.bayesian_network import _learn_structure
from pomegranate.bayesian_network import _categorical_exact
from pomegranate.bayesian_network import _categorical_chow_liu

from torchegranate.distributions import Categorical
from torchegranate.distributions import ConditionalCategorical
from pomegranate.distributions import Categorical
from pomegranate.distributions import ConditionalCategorical

from nose.tools import assert_raises
from numpy.testing import assert_array_almost_equal
Expand Down
9 changes: 4 additions & 5 deletions tests/test_factor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import torch
import pytest

from torchegranate.factor_graph import FactorGraph
from torchegranate.distributions import Exponential
from torchegranate.distributions import Categorical
from torchegranate.distributions import JointCategorical

from pomegranate.factor_graph import FactorGraph
from pomegranate.distributions import Exponential
from pomegranate.distributions import Categorical
from pomegranate.distributions import JointCategorical

from nose.tools import assert_raises
from numpy.testing import assert_array_almost_equal
Expand Down
4 changes: 2 additions & 2 deletions tests/test_gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import torch
import pytest

from torchegranate.gmm import GeneralMixtureModel
from torchegranate.distributions import Exponential
from pomegranate.gmm import GeneralMixtureModel
from pomegranate.distributions import Exponential

from .distributions._utils import _test_initialization_raises_one_parameter
from .distributions._utils import _test_initialization
Expand Down
2 changes: 1 addition & 1 deletion tests/test_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import pytest

from torchegranate.kmeans import KMeans
from pomegranate.kmeans import KMeans

from .distributions._utils import _test_initialization_raises_one_parameter
from .distributions._utils import _test_initialization
Expand Down
6 changes: 3 additions & 3 deletions tests/test_markov_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import torch
import pytest

from torchegranate.markov_chain import MarkovChain
from torchegranate.distributions import Categorical
from torchegranate.distributions import ConditionalCategorical
from pomegranate.markov_chain import MarkovChain
from pomegranate.distributions import Categorical
from pomegranate.distributions import ConditionalCategorical

from .distributions._utils import _test_initialization_raises_one_parameter
from .distributions._utils import _test_initialization
Expand Down
6 changes: 3 additions & 3 deletions tests/test_semisupervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import torch
import pytest

from torchegranate.gmm import GeneralMixtureModel
from torchegranate.hmm import DenseHMM
from torchegranate.distributions import Exponential
from pomegranate.gmm import GeneralMixtureModel
from pomegranate.hmm import DenseHMM
from pomegranate.distributions import Exponential

from nose.tools import assert_raises
from numpy.testing import assert_array_equal
Expand Down
8 changes: 4 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import torch
import pytest

from torchegranate._utils import _cast_as_tensor
from torchegranate._utils import _update_parameter
from torchegranate._utils import _check_parameter
from torchegranate._utils import partition_sequences
from pomegranate._utils import _cast_as_tensor
from pomegranate._utils import _update_parameter
from pomegranate._utils import _check_parameter
from pomegranate._utils import partition_sequences

from nose.tools import assert_almost_equal
from nose.tools import assert_equal
Expand Down

0 comments on commit 1731671

Please sign in to comment.