Description

The following simple Pallas kernel that copies an array hangs indefinitely.

Program output:
2024-05-17 10:18:51.191896: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Array elements: 4294967296
Array size: 17.1799GB
I remember reading somewhere that the above warning can be ignored, so I think it is unlikely to be related to the issue I'm seeing.
It looks like the program does not get stuck as long as batch_size * seq_length <= 2 ** 31. For example, changing either batch_size or seq_length from 2 ** 16 to 2 ** 15 makes it work fine. However, changing dtype from float32 to bfloat16 does not fix the problem. Also, I'm using an A100 80GB, and with batch_size = seq_length = 2 ** 16 and dtype=float32 the array only takes roughly 17 GB, so it probably has nothing to do with memory.
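For reference, 2 ** 31 is exactly the signed 32-bit integer boundary, which is consistent with (though does not by itself prove) an int32 indexing overflow rather than a memory problem. The arithmetic below reproduces the numbers in the program output:

```python
import numpy as np

batch_size = seq_length = 2 ** 16
n_elements = batch_size * seq_length   # 4294967296, matching "Array elements" above
size_gb = n_elements * 4 / 1e9         # float32 is 4 bytes: ~17.18 GB, matching "Array size"

int32_max = np.iinfo(np.int32).max     # 2 ** 31 - 1 = 2147483647
print(n_elements, size_gb, n_elements > int32_max)
```

The hanging configuration is the first one whose element count no longer fits in a signed 32-bit index.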
Also, when it hangs, both GPU and CPU utilization are zero.
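The kernel source itself did not survive the copy above. A minimal Pallas copy kernel consistent with the description is sketched below; the names (copy_kernel), the small shapes, and interpret=True (so the sketch runs on CPU) are my assumptions, not the reporter's exact code:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
    # Copy the whole input block to the output block.
    o_ref[...] = x_ref[...]

def copy(x):
    return pl.pallas_call(
        copy_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # CPU interpreter; drop this to run the compiled GPU kernel
    )(x)

# Small shapes so the sketch runs anywhere; the report uses
# batch_size = seq_length = 2 ** 16, at which point the GPU kernel hangs.
x = jnp.ones((8, 128), dtype=jnp.float32)
y = copy(x)
```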
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.28
jaxlib: 0.4.28
numpy: 1.26.4
python: 3.10.4 (main, Mar 31 2022, 08:41:55) [GCC 7.5.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='cn-g020.server.mila.quebec', release='5.15.0-101-generic', version='#111-Ubuntu SMP Tue Mar 5 20:16:58 UTC 2024', machine='x86_64')
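The version and platform block above looks like the output of JAX's built-in environment report, which can be regenerated with:

```python
import jax

# Prints jax/jaxlib/numpy versions, the Python version, visible devices,
# process count, and platform uname information.
jax.print_environment_info()
```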
$ nvidia-smi
Fri May 17 10:13:08 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.08 Driver Version: 535.161.08 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA A100-SXM4-80GB On | 00000000:41:00.0 Off | 0 |
| N/A 25C P0 72W / 500W | 424MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 3144259 C python 416MiB |
+---------------------------------------------------------------------------------------+
I also tested jax==0.4.25 with CUDA 11, which does not produce the ptxas version warning, but the kernel still hangs indefinitely; so the hang likely has nothing to do with that warning.