
[QUESTION] Is this speed as expected for inference in Bayesian networks? #1088

Open
ferencbartok opened this issue Mar 27, 2024 · 7 comments


ferencbartok commented Mar 27, 2024

I'm comparing different libraries for Bayesian inference, and I'm wondering whether my pomegranate tests show "correct" numbers or whether I'm using it wrong. (Somewhat related to this closed issue: #811)

My two test networks have 139 and 188 nodes, and 92 and 289 edges, respectively. Currently I'm doing inference without any evidence, and these are the numbers I get:

  • pomegranate: 0.25 sec; 0.7092 sec
  • pgmpy: 0.0661 sec; 0.3514 sec
  • SMILE: 0.0116 sec; 0.0535 sec

Pomegranate does approximate inference as far as I know; in pgmpy I'm using exact inference (VariableElimination); in SMILE, unfortunately, I'm not sure which algorithm is used, but most likely the default one, which is exact. Does pomegranate's speed seem okay, or should it be faster (I'm not utilizing a GPU yet)?
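
For reference, the pgmpy side of the comparison does roughly the following (a minimal sketch; pgmpy_model stands in for the separately built pgmpy network):

from pgmpy.inference import VariableElimination

# Exact inference baseline: one marginal query per variable, no evidence.
infer = VariableElimination(pgmpy_model)
marginals = {
    v: infer.query(variables=[v], show_progress=False)
    for v in pgmpy_model.nodes()
}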
Here is my code:

import torch
from pomegranate.bayesian_network import BayesianNetwork
from pomegranate.distributions import Categorical, ConditionalCategorical

model = BayesianNetwork()
node_id_to_distribution = {}

# Root nodes get a Categorical; nodes with parents get a ConditionalCategorical.
for node in nodes:
    if not node.parentIds:
        distribution = Categorical([node.probabilities])
    else:
        probabilities_reshaped = create_numpy_format_probabilities(node, node_id_to_node_dictionary)
        distribution = ConditionalCategorical(probabilities_reshaped)
    distribution.name = node.id
    node_id_to_distribution[node.id] = distribution
    model.add_distribution(distribution)

# TODO should not iterate twice...
for node in nodes:
    for parent_id in node.parentIds:
        model.add_edge(node_id_to_distribution[parent_id], node_id_to_distribution[node.id])

# Not setting evidence yet - only this part is included in the actual measurement.
X = torch.full((1, len(node_id_to_node_dictionary)), -1)
X_masked = torch.masked.MaskedTensor(X, mask=X >= 0)
predict_proba_result = model.predict_proba(X_masked)
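
The timing itself is just a wall clock around that last call, something like:

import time

start = time.perf_counter()
predict_proba_result = model.predict_proba(X_masked)
print(f"inference took {time.perf_counter() - start:.4f} s")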

Am I doing it right? Is there any setting other than max_iter that could speed up inference?

(Edit: I saw the comparison between old pomegranate, new, and new-with-GPU in the README, and based on that these numbers could be reasonable, but this depends a lot on edge and state counts as well, not just node count.)

jmschrei (Owner) commented

I haven't checked whether the structure of your network is fine, but the gist with PyTorch is that it has a relatively high startup cost for each operation, due to the overhead of managing tensors that might potentially be on a GPU, but the variable cost of bigger operations is smaller. Basically, the cost of doing an addition at all is higher, but as you need to add more numbers together, the total time doesn't grow as fast as in other libraries. So I'm not too surprised that sub-second operations might be a bit slower than in CPU-centric implementations. I would expect pomegranate to scale better than the other libraries, though.
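
You can see this fixed-versus-variable cost effect with a toy micro-benchmark (an illustrative sketch, not from pomegranate itself):

import time
import numpy as np
import torch

# Per-op overhead dominates for tiny arrays; throughput wins for big ones.
for n in (10, 1_000, 1_000_000):
    a_np, b_np = np.random.rand(n), np.random.rand(n)
    a_t, b_t = torch.rand(n), torch.rand(n)

    t0 = time.perf_counter()
    for _ in range(1_000):
        _ = a_np + b_np
    t_np = time.perf_counter() - t0

    t0 = time.perf_counter()
    for _ in range(1_000):
        _ = a_t + b_t
    t_torch = time.perf_counter() - t0

    print(f"n={n:>9,}: numpy {t_np:.4f}s  torch {t_torch:.4f}s")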

I'm not sure if adding a GPU would help too much, but it's def worth trying, because I'd also be surprised if it didn't help at all. If you have a lot of small operations, the I/O of moving things to the GPU may dominate the gains of doing the operations there. As the table size grows, GPUs will help a ton, as you can see in the examples in the README.

You're right that there isn't a fixed answer to "which is faster," because it depends on the number of nodes, the number of edges, and basically just table size. The rough way of thinking about it is whether it's slow because there are tons of small operations or because a few BIG operations are dominating the runtime. pomegranate does big operations way faster than most other libraries but -- for the reasons mentioned above -- does small operations a bit slower.


ferencbartok commented Mar 28, 2024

I understand, thank you for the detailed response!

Currently our most time-consuming calculation is this:

  • do an initial inference pass for all nodes/variables with no evidence (I use node and variable interchangeably)
  • iterate over roughly 40-60% of our nodes (let's call them "observables"; the share depends on the given network)
    • for every observable node, set each of its possible states as evidence in turn (usually around 2-4 states)
    • calculate inference for the remaining 40-60% of nodes*
    • calculate a value for the observable node that shows how much effect/impact changing its state has on the rest of the variables (~comparing the original no-evidence inference to the probabilities just calculated)
      (Maybe this use-case could be related to "do-calculus"?)

*Currently, for pomegranate, I always calculate for all the nodes; I'm not sure yet how to restrict the "targets". The whole loop looks roughly like the sketch below.
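
In code, the loop above is roughly the following (a sketch; observable_indices, n_states, and impact() are illustrative placeholders):

import torch

d = len(node_id_to_node_dictionary)
X0 = torch.full((1, d), -1)
baseline = model.predict_proba(torch.masked.MaskedTensor(X0, mask=X0 >= 0))

impacts = {}
for obs_idx in observable_indices:          # placeholder: columns of observable nodes
    per_state = []
    for state in range(n_states[obs_idx]):  # placeholder: state count per node
        x = torch.full((1, d), -1)
        x[0, obs_idx] = state               # set this single state as evidence
        x_masked = torch.masked.MaskedTensor(x, mask=x >= 0)
        per_state.append(model.predict_proba(x_masked))
    # Compare the conditioned posteriors against the no-evidence baseline.
    impacts[obs_idx] = impact(baseline, per_state)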

Unfortunately, based on your description, PyTorch, and hence pomegranate, is not "optimized" for this kind of workload. Admittedly, this solution is wasteful with any library, as most of what we recalculate should be the same as before. I'm wondering if you could suggest some way to leverage pomegranate's strengths to do this calculation faster?

Will definitely try the GPU as well; it's just not possible for us currently.


ferencbartok commented Apr 4, 2024

I have two more questions - hoping they fit here.

As mentioned earlier, I don't know how to calculate inference for just a given subset of variables. Is there a way to do it?
Often I set one piece of evidence and then want inference for only, say, x% of the variables, but
predict_proba_result = model.predict_proba(X_masked)
calculates it for all the nodes. Would it be possible not to calculate for all of them, while also not setting them as evidence? (Say I have A, B, C, D, E; I set A=1 as evidence and am only curious about the probabilities of C and D.)
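
For concreteness, what I do now looks like this (a sketch; the column ordering A..E is illustrative):

import torch

# Five variables A, B, C, D, E in columns 0..4.
X = torch.full((1, 5), -1)
X[0, 0] = 1                                # evidence: A = 1
X_masked = torch.masked.MaskedTensor(X, mask=X >= 0)

posteriors = model.predict_proba(X_masked)
# One posterior per variable comes back; I can pick out C and D afterwards,
# but the computation was still done for every node.
p_C, p_D = posteriors[2], posteriors[3]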

The other question relates to this comment on predict_proba: "..warning:: This inference is exact given a Bayesian network that has a tree-like structure, but is only approximate for other cases. When the network is acyclic, this procedure will converge, but if the graph contains cycles then there is no guarantee on convergence."
How can this be exact if pomegranate uses the sum-product/loopy belief propagation algorithm? I have a tree structure, and the numbers I get are sometimes more than just a bit off, which also points to approximate inference under the hood. Could you elaborate?


jmschrei commented Apr 4, 2024

In theory, you can avoid calculating the posterior for some variables if they are not in the Markov blanket of the ones you care about. Basically, if variable A could influence the value of variable B, you'd have to calculate both of them even if you only cared about B. But if A could not influence B then, you're right, you wouldn't need to calculate it. Unfortunately, this functionality is not implemented in pomegranate.
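
For illustration (a plain-Python sketch over a parent-list mapping, not pomegranate functionality): the Markov blanket of a node is its parents, its children, and its children's other parents.

# Sketch: Markov blanket from a {node_id: set_of_parent_ids} mapping.
def markov_blanket(node, parents):
    children = {n for n, ps in parents.items() if node in ps}
    blanket = set(parents[node]) | children
    for child in children:
        blanket |= set(parents[child])     # co-parents of each child
    blanket.discard(node)
    return blanket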

The sum-product algorithm is supposed to produce exact estimates for tree-structured networks. Are you running it until convergence (i.e., have you checked by setting max_iter higher)? https://github.com/jmschrei/pomegranate/blob/master/pomegranate/bayesian_network.py#L68
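
(For reference, those are constructor arguments, so something along these lines, assuming the signature at the link above:)

# Sketch: allow more belief-propagation iterations and a tighter tolerance.
model = BayesianNetwork(max_iter=100, tol=1e-8)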

ferencbartok (Author) commented

I understand, thank you!

I've tried tuning max_iter earlier and again now, but it slows inference down too much, and even doubling it didn't give good enough results in terms of the actual probabilities.


jmschrei commented Apr 7, 2024

Sorry that it's too slow for your application. For networks that have small tables, where not much batching of operations can be done, I agree that other CPU-based implementations may serve you better. If you can point me to implementations or papers about how to speed up your specific problem, I'd be happy to read over them and think about whether I can incorporate them, but I can't guarantee I'll have time to implement anything soon. Best of luck with your work!

ferencbartok (Author) commented

I understand, no problem! I don't know of such algorithms yet but will have a look around. Thank you!
