-
I am running a Bayesian inference model that at one point uses a small neural network (about 88 MB on disk). When I try to compile the whole model with JIT, I get this error:

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 6727663616 bytes.

I came across a Stack Overflow discussion that suggested setting the following flags:

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

but they didn't help either. My guess is that this is due to the large size of my NN and the data I am feeding it. That gave me an idea: can we call non-jit functions from jit functions? This is the GPU I was using on Kaggle:

Fri May 17 17:47:11 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla P100-PCIE-16GB          Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P0              31W / 250W |  12576MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
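In case it matters, I set the flags before importing JAX, since they are read at JAX initialization and are silently ignored if set afterwards. Roughly like this (a minimal sketch of my setup, not the full script):

```python
import os

# These must be set *before* JAX is imported, or they have no effect.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # don't grab most of GPU memory up front
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"   # cap preallocation at 50% of device memory
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" # allocate/free exactly what is requested

import jax  # imported only after the flags are in place
```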
-
To answer the question in your title: yes, this is possible via jax.pure_callback; see Exploring pure_callback for an example of how you might use it. The host-side computation there can be anything, including a call to a non-jit-compiled JAX function.

Answering your broader question: depending on the details of your code, I suspect a better approach here would be to avoid JIT-compiling your entire workflow, and instead compile small sections of it and then execute those JIT-compiled sections in sequence.
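To make the pure_callback suggestion concrete, here is a minimal sketch (the host function and shapes are illustrative, not from your model): the callback runs on the host, outside the compiled computation, and you must tell XLA the shape and dtype of its result up front.

```python
import numpy as np
import jax
import jax.numpy as jnp

def host_fn(x):
    # Runs on the host, outside the JIT-compiled computation.
    # This could be any Python code, including a non-jitted JAX function.
    return np.sin(np.asarray(x))

@jax.jit
def f(x):
    # XLA needs to know the callback's output shape/dtype ahead of time.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_fn, result_shape, x) + 1.0

x = jnp.arange(4, dtype=jnp.float32)
print(f(x))  # sin(x) + 1, with sin computed on the host
```

Note that data is transferred between device and host at each callback, so this can be slow; for a memory problem like yours, splitting the workflow into several smaller jitted pieces is usually the better first thing to try.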