
[ViT Model] [Perf Degradation] [x86] [ARM] torch.compile + weight prepacking results in perf degradation for ViT Transformer model #126391

Open
maajidkhann opened this issue May 16, 2024 · 4 comments
Labels: module: inductor, oncall: cpu inductor, oncall: pt2

Comments

maajidkhann (Contributor) commented May 16, 2024

🐛 Describe the bug

With PyTorch 2.3.0, when we run inference for the ViT DL model on CPUs, both on x86 (Ice Lake) and ARM (Graviton3), we see a performance degradation with torch.compile() + weight prepacking compared to torch.compile() alone.

[Screenshot: total inference time in seconds for 100 iterations, torch.compile() vs. torch.compile() + weight prepacking, on x86 and ARM]

Note: total time is in seconds for 100 iterations; only inference time is measured.
We can clearly see that performance gets worse after applying the weight-prepacking optimization, which shouldn't be the case.

I did a further deep-dive analysis and identified the root cause of this issue in the ViT model.

[Screenshot: profiler breakdown showing aten::addmm with torch.compile() vs. mkldnn::_linear_pointwise / mkl::_mkl_linear with torch.compile() + weight prepacking]

From the snapshot above, we can see that aten::addmm is called with torch.compile(), while mkldnn::_linear_pointwise (ARM) / mkl::_mkl_linear (x86) is called with torch.compile() + weight prepacking.

We see that aten::addmm is called with around 4 distinct input shapes during ViT model execution. Every shape shows a performance improvement with weight prepacking except one, highlighted in green/orange:
[[3072], [197, 768], [768, 3072], [], [], [197, 3072]]

This is the reason for the perf degradation. This shape needs to be analyzed and fixed to improve performance on CPU (both x86 and ARM).
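For reference, the per-shape breakdown above can be obtained with torch.profiler. A minimal sketch (the actual script is in the attachments below; `model` and `inputs` here are placeholders for an already-compiled ViT and its example input):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# `model` is assumed to be a ViT already wrapped with torch.compile();
# `inputs` is the corresponding example batch.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with torch.no_grad():
        for _ in range(100):
            model(inputs)

# Group by input shape to see which addmm / _mkl_linear shape regresses.
print(prof.key_averages(group_by_input_shape=True).table(
    sort_by="self_cpu_time_total", row_limit=10))
```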

Scripts to reproduce:
vit.txt
vit_with_weight_prepacking.txt

How to install:

pip3 install torch==2.3.0

export OMP_NUM_THREADS=32 (this was tested on 32-core machines: m7g.8xlarge Graviton3 for ARM and c6i.8xlarge Ice Lake for Intel).
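Since the attached scripts aren't inlined here, the following is only an illustrative sketch of the two configurations. It assumes a torchvision ViT-B/16 (which matches the [197, 768] activations) and Inductor freezing as the weight-prepacking mechanism; the iteration counts mirror the 100-iteration measurement above:

```python
import time
import torch
import torchvision.models as models

# Baseline run: torch.compile() only.
# For the "weight prepacking" run, additionally enable Inductor freezing,
# e.g. set TORCHINDUCTOR_FREEZING=1 before launching, or uncomment:
# torch._inductor.config.freezing = True

model = models.vit_b_16(weights="IMAGENET1K_V1").eval()
x = torch.randn(1, 3, 224, 224)

compiled = torch.compile(model)

with torch.no_grad():
    # Warm-up so compilation time is excluded from the measurement.
    for _ in range(10):
        compiled(x)
    start = time.time()
    for _ in range(100):
        compiled(x)
    print(f"total inference time for 100 iterations: {time.time() - start:.3f}s")
```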

Error logs

Model perf degradation with torch.compile() + weight prepacking on both x86 and ARM CPUs:

[Screenshot: benchmark timings showing the degradation on both platforms]

Minified repro

No response

Versions

For ARM:
collect_env_arm.txt

requirements.txt

For Intel:
Didn't build the wheel from source; installed directly with pip3 install torch==2.3.0 from PyPI.
requirements.txt

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire

leslie-fang-intel (Collaborator) commented:

@Valentine233 Could you help to take a look?

Valentine233 (Collaborator) commented:

Tried on SPR with 56 threads.

According to the MKL verbose log, the mkl_linear kernel time (for the highlighted shape in the issue) regresses starting from a certain moment:

257.85us -> 737.93us
MKL_VERBOSE SGEMM_COMPUTE(P,N,3072,197,768,0x7fdd360de040,768,0x1f629040,768,0x7ffcf8a68b38,0x1fca4540,3072) 257.85us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:56
MKL_VERBOSE SGEMM_COMPUTE(P,N,3072,197,768,0x7fde2a7d8040,768,0x1f99f980,768,0x7ffcf8a68b38,0x2003cdc0,3072) 737.93us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:56

By writing a small test case and running torch.addmm 2000 times, we only see the kernel perf at around 250 us.
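For context, a standalone microbenchmark of that shape might look like the sketch below. The exact test case used above isn't attached, so this is an assumption of the setup; the shape corresponds to input [197, 768], weight [3072, 768], bias [3072]:

```python
import time
import torch

torch.manual_seed(0)
x = torch.randn(197, 768)    # activation
w = torch.randn(3072, 768)   # linear weight; addmm consumes it transposed
b = torch.randn(3072)        # bias

with torch.no_grad():
    # Warm-up iterations.
    for _ in range(100):
        torch.addmm(b, x, w.t())
    start = time.time()
    for _ in range(2000):
        torch.addmm(b, x, w.t())
    elapsed = time.time() - start
    print(f"avg addmm time: {elapsed / 2000 * 1e6:.2f} us")
```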

maajidkhann (Contributor, Author) commented:

Regarding "257.85us -> 737.93us": will the regression be further investigated and fixed?

The issue really shows up when weight prepacking is enabled with torch.compile(), as highlighted in orange in the ticket. The input shape also changes once weight prepacking is enabled, compared to [[3072], [197, 768], [768, 3072], [], [], [197, 3072]] with just torch.compile().

Valentine233 (Collaborator) commented May 28, 2024:

There were some environment problems with the previous data. With tcmalloc and iomp5 enabled (intel-openmp needs to be installed), the performance with weight prepacking is better than without it.

Tested on Xeon SPR with 56 threads.

Without weight prepacking:

       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls                                                                      Input Shapes
aten::addmm        15.50%     476.529ms        17.17%     527.830ms     439.858us          1200                             [[768], [197, 3072], [3072, 768], [], [], [197, 768]]
aten::addmm        14.05%     432.157ms        17.59%     540.775ms     112.662us          4800                               [[768], [197, 768], [768, 768], [], [], [197, 768]]
aten::addmm        12.28%     377.698ms        14.79%     454.871ms     379.059us          1200                            [[3072], [197, 768], [768, 3072], [], [], [197, 3072]]

With weight prepacking:

       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls                                                                      Input Shapes
mkl::_mkl_linear        17.59%     458.226ms        17.80%     463.612ms      96.586us          4800                               [[197, 768], [2900193, 1], [768, 768], [], []]
mkl::_mkl_linear        14.35%     373.810ms        14.41%     375.306ms     312.755us          1200                             [[197, 3072], [5259489, 1], [768, 3072], [], []]
mkl::_mkl_linear        11.83%     308.105ms        11.89%     309.697ms     258.080us          1200                              [[197, 768], [5259489, 1], [3072, 768], [], []]

@maajidkhann Could you try with the environment settings mentioned above? It may be best to run with the PyTorch launcher: https://github.com/pytorch/pytorch/blob/main/torch/backends/xeon/run_cpu.py.
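For reference, a possible way to pick up tcmalloc and iomp5 is via LD_PRELOAD, or by letting the run_cpu launcher set up the environment and core binding. The library paths and script name below are illustrative and depend on your installation; check the launcher's --help for the exact allocator/OpenMP options in your PyTorch version:

```bash
# Option 1: preload the allocator / OpenMP runtime manually
# (paths are examples; adjust to where gperftools and intel-openmp live on your system)
pip3 install intel-openmp
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so:$CONDA_PREFIX/lib/libiomp5.so:$LD_PRELOAD

# Option 2: use the PyTorch CPU launcher
python -m torch.backends.xeon.run_cpu --help        # list allocator / OpenMP / core-binding options
python -m torch.backends.xeon.run_cpu your_vit_benchmark.py
```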
