
Error loading a saved model to run inference (using ddp_notebook strategy) #19869

Status: Open
carlos-havier opened this issue May 15, 2024 · 1 comment
Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.1.x

@carlos-havier
Bug description

Lightning throws an error when a saved model is loaded to run inference under the ddp_notebook strategy.

In this case, it throws the error: "RuntimeError: Lightning can't create new processes if CUDA is already initialized. Did you manually call torch.cuda.* functions, have moved the model to the device, or allocated memory on the GPU any other way? Please remove any such calls, or change the selected strategy. You will have to restart the Python kernel."

A minimal working example that reproduces the error is linked below.

What version are you seeing the problem on?

v2.1

How to reproduce the bug

https://colab.research.google.com/drive/1sxHACc95h-LcR48t3NYUfteLxaI4EmHB?usp=sharing

Error messages and logs

RuntimeError: Lightning can't create new processes if CUDA is already initialized. Did you manually call torch.cuda.* functions, have moved the model to the device, or allocated memory on the GPU any other way? Please remove any such calls, or change the selected strategy. You will have to restart the Python kernel

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA GeForce RTX 3080 Laptop GPU
    • available: True
    • version: 11.8
  • Lightning:
    • lightning: 2.2.1
    • lightning-utilities: 0.10.1
    • pytorch-lightning: 2.1.3
    • torch: 2.2.1
    • torchaudio: 2.2.1
    • torchmetrics: 1.2.1
    • torchvision: 0.17.1
  • Packages:
    • aiohttp: 3.9.3
    • aiosignal: 1.3.1
    • alembic: 1.13.1
    • anyio: 4.3.0
    • argon2-cffi: 23.1.0
    • argon2-cffi-bindings: 21.2.0
    • arrow: 1.3.0
    • asttokens: 2.4.1
    • async-lru: 2.0.4
    • attrs: 23.2.0
    • babel: 2.14.0
    • beautifulsoup4: 4.12.3
    • bleach: 6.1.0
    • brotli: 1.0.9
    • cached-property: 1.5.2
    • certifi: 2024.2.2
    • cffi: 1.16.0
    • charset-normalizer: 2.0.4
    • colorama: 0.4.6
    • colorlog: 6.8.2
    • comm: 0.2.1
    • contourpy: 1.2.0
    • cycler: 0.12.1
    • datasets: 2.18.0
    • debugpy: 1.8.1
    • decorator: 5.1.1
    • defusedxml: 0.7.1
    • dill: 0.3.8
    • entrypoints: 0.4
    • exceptiongroup: 1.2.0
    • executing: 2.0.1
    • fastjsonschema: 2.19.1
    • filelock: 3.13.1
    • fonttools: 4.49.0
    • fqdn: 1.5.1
    • frozenlist: 1.4.1
    • fsspec: 2024.2.0
    • ftfy: 6.1.3
    • gmpy2: 2.1.2
    • greenlet: 3.0.3
    • h11: 0.14.0
    • h2: 4.1.0
    • hpack: 4.0.0
    • httpcore: 1.0.4
    • httpx: 0.27.0
    • huggingface-hub: 0.21.3
    • hyperframe: 6.0.1
    • idna: 3.4
    • importlib-metadata: 7.0.1
    • importlib-resources: 6.1.2
    • ipykernel: 6.29.3
    • ipython: 8.22.2
    • ipywidgets: 8.1.2
    • isoduration: 20.11.0
    • jedi: 0.19.1
    • jinja2: 3.1.3
    • joblib: 1.3.2
    • json5: 0.9.22
    • jsonpointer: 2.4
    • jsonschema: 4.21.1
    • jsonschema-specifications: 2023.12.1
    • jupyter-client: 8.6.0
    • jupyter-core: 5.7.1
    • jupyter-events: 0.9.0
    • jupyter-lsp: 2.2.4
    • jupyter-server: 2.13.0
    • jupyter-server-terminals: 0.5.2
    • jupyterlab: 4.1.4
    • jupyterlab-pygments: 0.3.0
    • jupyterlab-server: 2.25.4
    • jupyterlab-widgets: 3.0.10
    • kiwisolver: 1.4.5
    • lightning: 2.2.1
    • lightning-utilities: 0.10.1
    • mako: 1.3.2
    • markupsafe: 2.1.3
    • matplotlib: 3.8.3
    • matplotlib-inline: 0.1.6
    • mistune: 3.0.2
    • mkl-fft: 1.3.8
    • mkl-random: 1.2.4
    • mkl-service: 2.4.0
    • mpmath: 1.3.0
    • multidict: 6.0.5
    • multiprocess: 0.70.16
    • munkres: 1.1.4
    • nbclient: 0.8.0
    • nbconvert: 7.16.2
    • nbformat: 5.9.2
    • nest-asyncio: 1.6.0
    • networkx: 3.1
    • notebook-shim: 0.2.4
    • numpy: 1.26.4
    • optuna: 3.5.0
    • overrides: 7.7.0
    • p-tqdm: 1.4.0
    • packaging: 23.2
    • pandas: 2.2.1
    • pandocfilters: 1.5.0
    • parso: 0.8.3
    • pathos: 0.3.2
    • patsy: 0.5.6
    • pexpect: 4.9.0
    • pickleshare: 0.7.5
    • pillow: 10.2.0
    • pip: 23.3.1
    • pkgutil-resolve-name: 1.3.10
    • platformdirs: 4.2.0
    • ply: 3.11
    • pox: 0.3.4
    • ppft: 1.7.6.8
    • prometheus-client: 0.20.0
    • prompt-toolkit: 3.0.42
    • psutil: 5.9.8
    • ptyprocess: 0.7.0
    • pure-eval: 0.2.2
    • pyarrow: 15.0.0
    • pyarrow-hotfix: 0.6
    • pycparser: 2.21
    • pygments: 2.17.2
    • pyparsing: 3.1.1
    • pyqt5: 5.15.10
    • pyqt5-sip: 12.13.0
    • pysocks: 1.7.1
    • python-dateutil: 2.9.0
    • python-json-logger: 2.0.7
    • pytorch-lightning: 2.1.3
    • pytz: 2024.1
    • pyyaml: 6.0.1
    • pyzmq: 25.1.2
    • referencing: 0.33.0
    • regex: 2023.12.25
    • requests: 2.31.0
    • rfc3339-validator: 0.1.4
    • rfc3986-validator: 0.1.1
    • rpds-py: 0.18.0
    • safetensors: 0.4.2
    • scikit-learn: 1.4.1.post1
    • scipy: 1.12.0
    • seaborn: 0.13.2
    • send2trash: 1.8.2
    • setuptools: 68.2.2
    • sip: 6.7.12
    • six: 1.16.0
    • sniffio: 1.3.1
    • soupsieve: 2.5
    • sqlalchemy: 2.0.28
    • stack-data: 0.6.2
    • statsmodels: 0.14.1
    • sympy: 1.12
    • terminado: 0.18.0
    • threadpoolctl: 3.3.0
    • timm: 0.9.16
    • tinycss2: 1.2.1
    • tokenizers: 0.15.2
    • tomli: 2.0.1
    • torch: 2.2.1
    • torchaudio: 2.2.1
    • torchmetrics: 1.2.1
    • torchvision: 0.17.1
    • tornado: 6.4
    • tqdm: 4.66.2
    • traitlets: 5.14.1
    • transformers: 4.38.2
    • triton: 2.2.0
    • types-python-dateutil: 2.8.19.20240311
    • typing-extensions: 4.9.0
    • typing-utils: 0.1.0
    • tzdata: 2024.1
    • uri-template: 1.3.0
    • urllib3: 2.1.0
    • wcwidth: 0.2.13
    • webcolors: 1.13
    • webencodings: 0.5.1
    • websocket-client: 1.7.0
    • wheel: 0.41.2
    • widgetsnbextension: 4.0.10
    • xxhash: 3.4.1
    • yarl: 1.9.4
    • zipp: 3.17.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.11.8
    • release: 5.15.0-107-generic
    • version: #117~20.04.1-Ubuntu SMP Tue Apr 30 10:35:57 UTC 2024

More info

No response

carlos-havier added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on May 15, 2024
@LawJarp-A commented May 23, 2024

@carlos-havier As mentioned in the PyTorch Lightning documentation, when using ddp_notebook, the downside is:

"GPU operations such as moving tensors to the GPU or calling torch.cuda functions before invoking Trainer.fit is not allowed."

This means that no CUDA tensors may exist before calling Trainer.fit. By default, when training on a GPU, PyTorch Lightning saves the checkpoint's state_dict with CUDA tensors, so loading from that checkpoint initializes CUDA. You can verify this with a simple check:

print(torch.cuda.is_initialized())

This can be placed before and after calling:

pl_model = LT_timm_model.load_from_checkpoint(def_log_chkpt)

You'll observe that CUDA becomes initialized during load_from_checkpoint, and once it is initialized in the main process, the fork-based worker processes that ddp_notebook requires can no longer be created.
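To illustrate the check in isolation, here is a minimal sketch using plain torch.save/torch.load as a stand-in for a Lightning checkpoint (the file name and tensor contents are made up for the example):

```python
import torch

# In a fresh process, CUDA has not been touched yet.
print(torch.cuda.is_initialized())  # False

# Stand-in for a Lightning checkpoint: a dict holding CPU tensors.
torch.save({"state_dict": {"weight": torch.zeros(2, 2)}}, "demo.ckpt")

# map_location='cpu' forces every tensor onto the CPU, so loading
# the checkpoint does not initialize CUDA as a side effect.
loaded = torch.load("demo.ckpt", map_location="cpu")
print(loaded["state_dict"]["weight"].device)  # cpu
print(torch.cuda.is_initialized())  # still False
```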

The Fix:

Pass map_location=torch.device('cpu') when calling load_from_checkpoint:

pl_model = LT_timm_model.load_from_checkpoint(def_log_chkpt, map_location=torch.device('cpu'))

For your reference, I have added a few lines to debug based on your notebook here.
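Putting the fix together, a minimal sketch of the CPU-side loading pattern (a plain nn.Linear stands in for the notebook's LT_timm_model, an assumption made purely for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the notebook's LightningModule.
model = nn.Linear(4, 2)
torch.save({"state_dict": model.state_dict()}, "ckpt.pt")

# Load the checkpoint on the CPU so CUDA is never initialized
# before the Trainer spawns its ddp_notebook workers.
ckpt = torch.load("ckpt.pt", map_location=torch.device("cpu"))
model.load_state_dict(ckpt["state_dict"])

# CUDA remains uninitialized; ddp_notebook can now fork safely.
assert not torch.cuda.is_initialized()
```

With a real LightningModule, the same idea is load_from_checkpoint(..., map_location=torch.device('cpu')) as shown above, letting the Trainer move the model to the GPU inside each worker.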
