This rewrite was motivated by four main reasons:
- <b>Community Contribution</b>: A challenge that many people faced when using pomegranate was that they could not extend it because they did not know Cython, and even if they did know it, coding in Cython is a pain. I felt this pain every time I tried adding a new feature or fixing a bug. Using PyTorch as the backend significantly reduces the amount of effort needed to add in new features.
- <b>Interoperability</b>: Libraries like PyTorch offer a unique ability to not just utilize their computational backends but to better integrate into existing deep learning resources. This rewrite should make it easier for people to merge probabilistic models with deep learning models.

### High-level Changes

1. General
- The entire codebase has been rewritten in PyTorch and all models are instances of `torch.nn.Module`
	- The codebase is checked by a comprehensive suite of >800 unit tests that call assert statements several thousand times, far more than in previous versions.

2. Features
- All models and methods now have GPU support
- All models and methods now have support for half/mixed precision
- Serialization is now handled by PyTorch, yielding more compact and efficient I/O
	- Missing values are now supported through `torch.masked.MaskedTensor` objects
	- Prior probabilities can now be passed to all relevant models and methods, enabling more comprehensive and flexible semi-supervised learning than before

3. Models
- All distributions are now multivariate by default, supporting speedups through batched operations
- "Distribution" has been removed from distribution objects so that, for example, `NormalDistribution` is now `Normal`.
- Factor graphs are now supported as first-class citizens
	- Hidden Markov models have been split into `DenseHMM` and `SparseHMM` models, which differ in how the transition matrix is encoded, with `DenseHMM` objects being significantly faster (a short sketch follows this list)

4. Feature Removals
- `NaiveBayes` has been permanently removed as it is redundant with `BayesClassifier`
	- `MarkovNetwork` has been temporarily removed
	- Constraint graphs and constrained structure learning for Bayesian networks have been temporarily removed
- Silent states for hidden Markov models have been temporarily removed
- Viterbi for hidden Markov models has been temporarily removed

I hope to soon re-add the features that have been temporarily removed.
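
As a point of reference for the HMM split mentioned above, here is a minimal sketch of constructing and fitting a `DenseHMM`. It assumes the v1.0.0 import paths (`pomegranate.hmm`, `pomegranate.distributions`) and that `fit` accepts a 3D tensor of shape (batch, sequence length, dimensions); a `SparseHMM` would be built the same way but with an explicit edge list.

```python
>>> import torch
>>> from pomegranate.hmm import DenseHMM
>>> from pomegranate.distributions import Normal

>>> X = torch.randn(50, 20, 4)              # 50 sequences, each of length 20 with 4 features
>>> model = DenseHMM([Normal(), Normal()])  # two hidden states; the full transition matrix is stored densely
>>> model.fit(X)
```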

## Speed

Most models and methods in pomegranate v1.0.0 are faster than their counterparts in earlier versions. This generally scales with complexity: one sees only small speedups for simple distributions on small data sets but much larger speedups for more complex models on big data sets, e.g. hidden Markov model training or Bayesian network inference. The notable exception for now is that Bayesian network structure learning, other than Chow-Liu tree building, is still incomplete and not much faster. In the examples below, `torchegranate` refers to the temporary repository used to develop pomegranate v1.0.0 and `pomegranate` refers to pomegranate v0.14.8.

### K-Means

Who knows what's happening here? Wild.

![image](https://user-images.githubusercontent.com/3916816/232371843-66b9d326-b4de-4da0-bbb1-5eab5f9a4492.png)


### Hidden Markov Models

Dense transition matrix (CPU)

![image](https://user-images.githubusercontent.com/3916816/232370752-58969609-5ee4-417f-a0da-1fbb83763d63.png)

Sparse transition matrix (CPU)

![image](https://user-images.githubusercontent.com/3916816/232371006-20a82e07-3553-4257-987b-d8e9b333933a.png)


### Bayesian Networks
![image](https://user-images.githubusercontent.com/3916816/232370594-e89e66a8-d9d9-4369-ba64-8902d8ec2fcc.png)
![image](https://user-images.githubusercontent.com/3916816/232370632-199d5e99-0cd5-415e-9c72-c4ec9fb7a44c.png)


## Features

> **Note**
Expand Down Expand Up @@ -72,6 +127,18 @@ Parameter containing:
tensor([1.9902, 2.3871, 0.8984, 1.2215], device='cuda:0')
```
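
For context, the output above comes from a model fit on the GPU. A minimal sketch of that workflow, assuming only the standard `torch.nn.Module` semantics of calling `.cuda()` on both the model and the data (the `Normal` import is elided here, as in the other snippets):

```python
>>> X = torch.randn(5000, 4).cuda()            # move the data to the GPU
>>> d = Normal(covariance_type='diag').cuda()  # models are torch.nn.Modules, so .cuda() works
>>> d.fit(X)
```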

### Mixed Precision

pomegranate models can, in theory, operate in the same mixed or low-precision regimes as other PyTorch modules. However, because pomegranate uses more complex operations than most neural networks, this does not always work or help in practice, as those operations have not yet been optimized or implemented for low-precision regimes. Hopefully, this feature will become more useful over time.

```python
>>> X = torch.randn(100, 4)
>>> d = Normal(covariance_type='diag')
>>>
>>> with torch.autocast('cuda', dtype=torch.bfloat16):
...     d.fit(X)
```

### Serialization

pomegranate distributions are all instances of `torch.nn.Module`, so serialization works the same as for any other PyTorch model and can use any of the built-in functionality.
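
A minimal sketch of what that looks like, assuming the standard `torch.save`/`torch.load` round trip (the filename is illustrative):

```python
>>> d = Normal(covariance_type='diag')
>>> d.fit(torch.randn(100, 4))
>>>
>>> torch.save(d, "normal.pt")    # pickles the whole module, parameters included
>>> d2 = torch.load("normal.pt")  # restores a ready-to-use model
```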
All algorithms currently treat missingness as something to ignore.

Because not all operations are yet available for MaskedTensors, the following distributions are not yet supported for missing values: Bernoulli, categorical, normal with full covariance, and uniform.
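
A minimal sketch of fitting through missing values, assuming `torch.masked.MaskedTensor` wraps the data with a mask that is `True` wherever a value was observed (the NaN sentinel below is just for illustration; a diagonal-covariance normal is used since full covariance is not yet supported):

```python
>>> X = torch.randn(100, 4)
>>> X[X < -2] = float('nan')  # pretend some entries are missing
>>>
>>> X_masked = torch.masked.MaskedTensor(X, mask=~torch.isnan(X))
>>> d = Normal(covariance_type='diag').fit(X_masked)
```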

### Prior Probabilities and Semi-supervised Learning

A new feature in pomegranate v1.0.0 is the ability to pass in prior probabilities for each observation for mixture models, Bayes classifiers, and hidden Markov models. These are the prior probabilities that an observation belongs to each component of the model before evaluating the likelihood, and they should range between 0 and 1. When a value of 1.0 is given for an observation, it acts as a label, because the likelihood no longer matters when assigning that observation to a state. Hence, one can use these prior probabilities to do labeled training when every observation has a 1.0 for some state, semi-supervised learning when only a subset of observations do (including when sequences are only partially labeled for hidden Markov models), or more sophisticated forms of weighting when the values are between 0 and 1.
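
As a minimal sketch, assuming a mixture model's `fit` accepts a `priors` argument of shape (n_observations, n_components) (the exact keyword and the `pomegranate.gmm` import path are assumptions here):

```python
>>> from pomegranate.gmm import GeneralMixtureModel
>>>
>>> X = torch.randn(100, 4)
>>> priors = torch.full((100, 2), 0.5)    # uninformative for most observations
>>> priors[0] = torch.tensor([1.0, 0.0])  # observation 0 is effectively labeled as component 0
>>>
>>> model = GeneralMixtureModel([Normal(), Normal()])
>>> model.fit(X, priors=priors)
```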

![image](https://user-images.githubusercontent.com/3916816/232373036-39d591e2-e673-450e-ab1c-98e47f0fa6aa.png)


### Frequently Asked Questions

> Why can't we just use `torch.distributions`?
