JAX implementation of a differentiable PDE solver with jump conditions across irregular interfaces in 3D.

JAX-DIPS implements the neural bootstrapping method (NBM) (see citations below) to train compact neural network surrogate models by leveraging efficient finite discretization methods for treatment of spatial gradients as well as automatic differentiation for training neural network parameters.

  • Compact models: unlike other training methods, we show that complex solutions can be learned by shallow, compact neural networks with 1000x fewer trainable parameters when finite discretization residuals are used to train the networks.

  • Performance: using finite discretizations (FD) to compute the PDE residuals restricts automatic differentiation to first-order AD only. This significantly reduces the computational and memory costs of evaluating the loss for higher-order PDEs, where PINNs must apply AD repeatedly. We achieve a 10x speedup compared to backpropagation-based solvers.

  • Accuracy: carefully designed numerical discretization schemes (see, for example, Gibou, Fedkiw, and Osher 2018) treat spatial gradients in the presence of discontinuities and irregular interfaces. They inform the neural network about the mathematical symmetries and constraints (e.g., conservation laws enforced through finite volume discretizations) in local neighborhoods/voxels centered at the training points. These extra mathematical constraints improve the regularity and accuracy of the learned neural surrogate models for PDEs in three spatial dimensions.

  • Cross-pollination of applied mathematics and machine learning: JAX-DIPS makes it possible to leverage advanced preconditioners developed in the high-performance scientific computing community for faster and more accurate training of neural network models. An example is the algebraic multigrid (AMG) preconditioner in hypre.
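To make the Performance point concrete, here is a minimal plain-Python sketch of a finite-difference residual loss. It is not the JAX-DIPS discretization, which additionally treats interfaces and jump conditions. It assumes a uniform grid spacing h and the standard 7-point Laplacian. The surrogate u is only evaluated at neighbouring stencil points. So when u is a neural network, automatic differentiation is needed just once: for the gradient of the loss with respect to the network parameters.

```python
from typing import Callable, Sequence, Tuple

Point = Tuple[float, float, float]
ScalarField = Callable[[float, float, float], float]


def fd_laplacian(u: ScalarField, p: Point, h: float) -> float:
    """Standard 7-point finite-difference Laplacian of u at p on a uniform grid of spacing h.

    Only point evaluations of u are needed: no derivative of u is taken by autodiff.
    """
    x, y, z = p
    neighbours = (
        u(x + h, y, z) + u(x - h, y, z)
        + u(x, y + h, z) + u(x, y - h, z)
        + u(x, y, z + h) + u(x, y, z - h)
    )
    return (neighbours - 6.0 * u(x, y, z)) / (h * h)


def residual_loss(u: ScalarField, f: ScalarField, points: Sequence[Point], h: float) -> float:
    """Mean squared residual of the Poisson equation lap(u) = f over the training points."""
    residuals = [fd_laplacian(u, p, h) - f(*p) for p in points]
    return sum(r * r for r in residuals) / len(residuals)


def u(x: float, y: float, z: float) -> float:
    """Candidate solution; the 7-point stencil is exact for quadratics, and lap(u) = 6."""
    return x * x + y * y + z * z


def f(x: float, y: float, z: float) -> float:
    """Right-hand side matching u."""
    return 6.0


loss = residual_loss(u, f, [(0.1, 0.2, 0.3), (-0.4, 0.0, 0.25)], h=0.05)
print(loss)  # ~0 up to floating-point rounding
```

In a JAX setting, u would be the network evaluated with its parameters. Then jax.grad of this loss with respect to those parameters is the only derivative that AD has to produce.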

Quick Example

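The Poisson solver provides an interface for passing in functions that define the different terms of the interfacial PDE, given in the form

$$\nabla\cdot(\mu^{\pm}\nabla u^{\pm}) + k^{\pm}u^{\pm} = f^{\pm}, \quad \mathbf{x}\in\Omega^{\pm}$$
$$[u] = \alpha, \quad \mathbf{x}\in\Gamma$$
$$[\mu\,\partial_{\mathbf{n}} u] = \beta, \quad \mathbf{x}\in\Gamma$$

where Γ is the interface separating the subdomains Ω⁻ and Ω⁺, and [·] denotes the jump across Γ.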

from jax import numpy as jnp
from jax import vmap

from jax_dips._jaxmd_modules.util import f32
from jax_dips.domain import mesh
from jax_dips.solvers.poisson import trainer
from jax_dips.utils import io

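init_mesh_fn, coord_at = mesh.construct(3)

# --------- Domain bounds and grid resolutions (illustrative values; choose them for your problem)
xmin, xmax = -1.0, 1.0
ymin, ymax = -1.0, 1.0
zmin, zmax = -1.0, 1.0
Nx_tr = Ny_tr = Nz_tr = 16  # training grid
Nx_lvl = Ny_lvl = Nz_lvl = 64  # level-set grid
Nx_eval = Ny_eval = Nz_eval = 32  # visualization grid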

# --------- Grid nodes for training
xc = jnp.linspace(xmin, xmax, Nx_tr, dtype=f32)
yc = jnp.linspace(ymin, ymax, Ny_tr, dtype=f32)
zc = jnp.linspace(zmin, zmax, Nz_tr, dtype=f32)
gstate_tr = init_mesh_fn(xc, yc, zc)

# --------- Grid nodes for level set representation
xc = jnp.linspace(xmin, xmax, Nx_lvl, dtype=f32)
yc = jnp.linspace(ymin, ymax, Ny_lvl, dtype=f32)
zc = jnp.linspace(zmin, zmax, Nz_lvl, dtype=f32)
gstate_lvl = init_mesh_fn(xc, yc, zc)

# ---------- Grid points for visualization
exc = jnp.linspace(xmin, xmax, Nx_eval, dtype=f32)
eyc = jnp.linspace(ymin, ymax, Ny_eval, dtype=f32)
ezc = jnp.linspace(zmin, zmax, Nz_eval, dtype=f32)
eval_gstate = init_mesh_fn(exc, eyc, ezc)

# ---------- Define the neural network surrogate architecture and optimizer
optimizer_dict: dict = {
    "optimizer_name": "custom",
    "learning_rate": 1e-3,
    "sched": {"scheduler_name": "exponential", "decay_rate": 0.9},
}

model_dict: dict = {
    "name": None,
    "model_type": "mlp",
    "mlp": {
        "hidden_layers_m": 1,
        "hidden_dim_m": 1,
        "activation_m": "jnp.tanh",
        "hidden_layers_p": 2,
        "hidden_dim_p": 10,
        "activation_p": "jnp.tanh",
    },
}

# ---------- Define, instantiate, then train/solve the poisson problem
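# trainer.setup receives the user-defined functions for the PDE terms (including the level set phi_fn);
# its arguments are elided in this excerpt.
init_fn = trainer.setup(
    ...
)
# init_fn presumably receives the grid states and the optimizer_dict/model_dict defined above;
# its arguments are elided in this excerpt.
sim_state, solve_fn = init_fn(
    ...
)
sim_state, epoch_store, loss_epochs = solve_fn(sim_state=sim_state)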

# ---------- Save results in vtk file
eval_phi = vmap(phi_fn)(eval_gstate.R)

log = {
    "phi": eval_phi,
    "u": sim_state.solution,
save_name = log_dir + "/" + molecule_name
io.write_vtk_manual(eval_gstate, log, filename=save_name)
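The terms handed to the Poisson solver are plain functions of position. As a hypothetical illustration, the sketch below computes the two interface jumps for a spherical interface. The names and signatures here are not the JAX-DIPS API. It assumes three common conventions:

  • the level set is negative in Ω⁻ and positive in Ω⁺;
  • the normal n = ∇φ/|∇φ| points into Ω⁺;
  • [q] = q⁺ - q⁻.

The inside and outside solutions are made up for the example.

```python
import math
from typing import Callable, Tuple

Point = Tuple[float, float, float]
Field = Callable[[float, float, float], float]

R0 = 0.5  # hypothetical sphere radius


def sphere_level_set(x: float, y: float, z: float) -> float:
    """Signed distance to a sphere: negative inside (Omega^-), positive outside (Omega^+)."""
    return math.sqrt(x * x + y * y + z * z) - R0


def u_m(x: float, y: float, z: float) -> float:
    """Hypothetical solution inside the sphere."""
    return 1.0


def u_p(x: float, y: float, z: float) -> float:
    """Hypothetical solution outside the sphere."""
    return 1.0 + x * x + y * y + z * z


def grad(fn: Field, p: Point, h: float = 1e-5) -> Point:
    """Central-difference gradient of fn at p."""
    x, y, z = p
    return (
        (fn(x + h, y, z) - fn(x - h, y, z)) / (2 * h),
        (fn(x, y + h, z) - fn(x, y - h, z)) / (2 * h),
        (fn(x, y, z + h) - fn(x, y, z - h)) / (2 * h),
    )


def unit_normal(p: Point) -> Point:
    """grad(phi) / |grad(phi)|, pointing from Omega^- into Omega^+."""
    g = grad(sphere_level_set, p)
    norm = math.sqrt(sum(c * c for c in g))
    return (g[0] / norm, g[1] / norm, g[2] / norm)


def jumps(p: Point, mu_m: float, mu_p: float) -> Tuple[float, float]:
    """Return ([u], [mu du/dn]) at an interface point p, using [q] = q^+ - q^-."""
    n = unit_normal(p)
    dudn_p = sum(a * b for a, b in zip(grad(u_p, p), n))
    dudn_m = sum(a * b for a, b in zip(grad(u_m, p), n))
    alpha = u_p(*p) - u_m(*p)
    beta = mu_p * dudn_p - mu_m * dudn_m
    return alpha, beta


alpha, beta = jumps((R0, 0.0, 0.0), mu_m=1.0, mu_p=2.0)
print(alpha, beta)
```

For these choices, [u] = 0.25 and [μ ∂ₙu] ≈ 2 at the point (R0, 0, 0). In a real problem, α and β are prescribed data, and the solver must find u⁻ and u⁺ that satisfy them.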

Library Structure

Streamlines of the solution gradient (left) and the jump in the solution (right), computed for the dragon example.


Models are stored in the jax_dips.nn module. They are provided to the Poisson solver through the get_model(**model_dict) API defined in the jax_dips.nn.configure module; when adding a new model, you only need to add it to this API. The model parameters (i.e., model_dict) are passed to the jax_dips.solvers.poisson.trainer module. They are usually defined in the YAML configuration file for Hydra, similar to:

  model_type: "mlp"
  mlp:
    hidden_layers_m: 1
    hidden_dim_m: 3
    activation_m: "jnp.tanh"
    hidden_layers_p: 2
    hidden_dim_p: 10
    activation_p: "jnp.tanh"
  resnet:  # section key assumed; settings for a residual-network model
    res_blocks_m: 3
    res_dim_m: 40
    activation_m: "nn.tanh"
    res_blocks_p: 3
    res_dim_p: 80
    activation_p: "nn.tanh"
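The "add it to this API" workflow can be pictured as a registry that dispatches on model_type. This is a hypothetical sketch, not the code in jax_dips.nn.configure. The model section matching model_type is passed to its builder as keyword arguments. The _m/_p suffixes are assumed to denote the minus (Ω⁻) and plus (Ω⁺) sub-networks. build_mlp only records the settings instead of building a real network.

```python
from typing import Any, Callable, Dict

Builder = Callable[..., Dict[str, Any]]

# Hypothetical registry: model_type -> builder function.
_MODEL_BUILDERS: Dict[str, Builder] = {}


def register_model(model_type: str) -> Callable[[Builder], Builder]:
    """Decorator that makes a builder available under `model_type`."""
    def decorator(builder: Builder) -> Builder:
        _MODEL_BUILDERS[model_type] = builder
        return builder
    return decorator


@register_model("mlp")
def build_mlp(hidden_layers_m: int, hidden_dim_m: int, activation_m: str,
              hidden_layers_p: int, hidden_dim_p: int, activation_p: str) -> Dict[str, Any]:
    """Stand-in for a network constructor: one sub-network per region (assumed _m/_p meaning)."""
    return {
        "minus": {"layers": hidden_layers_m, "width": hidden_dim_m, "activation": activation_m},
        "plus": {"layers": hidden_layers_p, "width": hidden_dim_p, "activation": activation_p},
    }


def get_model(model_type: str, name: Any = None, **sections: Dict[str, Any]) -> Dict[str, Any]:
    """Look up the builder for model_type and call it with that model's own config section."""
    if model_type not in _MODEL_BUILDERS:
        raise ValueError(f"unknown model_type {model_type!r}; known: {sorted(_MODEL_BUILDERS)}")
    return _MODEL_BUILDERS[model_type](**sections.get(model_type, {}))


model_dict = {
    "name": None,
    "model_type": "mlp",
    "mlp": {
        "hidden_layers_m": 1, "hidden_dim_m": 1, "activation_m": "jnp.tanh",
        "hidden_layers_p": 2, "hidden_dim_p": 10, "activation_p": "jnp.tanh",
    },
}
model = get_model(**model_dict)
print(model)
```

With this design, sections for other model types (such as the res_* block) can sit in the same config and are simply ignored unless selected. A new model then needs only a decorated builder.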


Explicit and implicit automatic differentiation are provided through the jaxopt and optax packages. optax is configured in the jax_dips.solvers.optimizers module, while calls to jaxopt are currently configured directly in the jax_dips.solvers.poisson.trainer module. In the YAML file, this can be configured as follows:

      optimizer_name: "custom" # options are "custom", "adam", "rmsprop", "lbfgs"
      learning_rate: 1e-3
      sched:  # learning rate scheduler 
        scheduler_name: "exponential" # options are "exponential", "polynomial"
        decay_rate: 0.9

Currently, choosing lbfgs invokes the jaxopt package, which uses implicit differentiation.

Note: in the current version we support data-parallel training using optax.
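The two scheduler options can be illustrated with the decay formulas used by optax.exponential_decay and optax.polynomial_schedule. This is a plain-Python sketch assuming JAX-DIPS follows those formulas. It also assumes one transition step per optimizer step. The polynomial keys (end_value, power, transition_steps) are hypothetical config names.

```python
from typing import Callable, Dict

Schedule = Callable[[int], float]


def exponential_schedule(init_value: float, decay_rate: float, transition_steps: int = 1) -> Schedule:
    """lr(step) = init_value * decay_rate ** (step / transition_steps), as in optax.exponential_decay."""
    def schedule(step: int) -> float:
        return init_value * decay_rate ** (step / transition_steps)
    return schedule


def polynomial_schedule(init_value: float, end_value: float, power: float,
                        transition_steps: int) -> Schedule:
    """Decay from init_value to end_value over transition_steps, then stay at end_value."""
    def schedule(step: int) -> float:
        frac = 1.0 - min(max(step, 0), transition_steps) / transition_steps
        return (init_value - end_value) * frac ** power + end_value
    return schedule


SCHEDULERS = {"exponential": exponential_schedule, "polynomial": polynomial_schedule}


def make_schedule(learning_rate: float, sched: Dict[str, object]) -> Schedule:
    """Build a schedule from a `sched` section like the one in the YAML config."""
    options = dict(sched)  # copy so the caller's config is not mutated
    name = options.pop("scheduler_name")
    return SCHEDULERS[name](learning_rate, **options)


lr = make_schedule(1e-3, {"scheduler_name": "exponential", "decay_rate": 0.9})
print([lr(step) for step in range(3)])  # 1e-3, then multiplied by 0.9 at every step
```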


From the parent directory, run pytest tests/test_<name>.py for any of the available tests below, or pytest tests/test_*.py to run them all:

  • test_advection: a sphere is rotated 360 degrees around the box so that it returns to its initial configuration. The advection uses a semi-Lagrangian scheme with Sussman reinitialization. The test passes if the L2 error in the level-set function is less than 1e-4.
  • test_reinitialization: starts from a sphere level-set function equal to -1 inside the sphere and +1 outside. Sussman reinitialization is applied repeatedly until the level set has the signed-distance property. The test passes if the level-set value at the center of the box equals the radius of the sphere and the value at a corner of the box equals a pre-specified distance.
  • test_geometric_integrations: computes the surface area and the volume of a sphere by numerical integration. The test passes if both are close to their theoretical values.
  • test_poisson: tests both the pointwise and the grid-based Poisson solvers on star-shaped and spherical interfaces. Note that in the current implementation the grid-based solver does not support batching and is therefore faster. Batching support is planned for a future version.
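The geometric-integration check can be sketched as follows, assuming a signed-distance level set sampled at the cell centres of a uniform grid. Volume comes from a sharp Heaviside, area from a cosine-smoothed delta function. This smeared-delta quadrature is a standard textbook choice, not necessarily the one used in test_geometric_integrations.

```python
import math
from typing import Tuple


def sphere_volume_and_area(radius: float, n: int = 40, half_width: float = 1.0) -> Tuple[float, float]:
    """Estimate volume and surface area of a sphere centred in [-half_width, half_width]^3.

    phi = |x| - radius is sampled at the centres of an n^3 uniform grid of spacing h.
    Volume: sum of cell volumes with phi < 0 (sharp Heaviside).
    Area:   sum of delta_eps(phi) * |grad phi| * cell volume, where |grad phi| = 1 for a
            signed distance and delta_eps is a cosine-smoothed delta of half-width eps = 1.5 h.
    """
    h = 2.0 * half_width / n
    eps = 1.5 * h
    dv = h ** 3
    centres = [-half_width + (i + 0.5) * h for i in range(n)]
    volume = 0.0
    area = 0.0
    for x in centres:
        for y in centres:
            for z in centres:
                phi = math.sqrt(x * x + y * y + z * z) - radius
                if phi < 0.0:
                    volume += dv
                if abs(phi) < eps:
                    area += (1.0 + math.cos(math.pi * phi / eps)) / (2.0 * eps) * dv
    return volume, area


vol, area = sphere_volume_and_area(0.5)
exact_vol = 4.0 / 3.0 * math.pi * 0.5 ** 3   # about 0.5236
exact_area = 4.0 * math.pi * 0.5 ** 2         # about 3.1416
print(f"volume {vol:.4f} vs {exact_vol:.4f}, area {area:.4f} vs {exact_area:.4f}")
```

At this resolution, both estimates should land within a few percent of the exact values. A test would compare them with a tolerance of that order.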


To install the latest released version from PyPI, run pip install jax-dips. To create a dedicated virtual environment for jax-dips, use python3 -m venv <my-virtual-env> and activate it with source <my-virtual-env>/bin/activate. Installing jax-dips inside this environment ensures that it does not interfere with your existing library installations. Note that jax-dips installs jax on your machine, which is why a virtual environment is recommended.

Development & Usage

0. Hardware Requirements

The hardware requirements below assume an NVIDIA GPU with CUDA version >=12.* and a CUDA driver version >=530. If you are using a different CUDA version, replace the Python wheel address on the first line of requirements.txt with the wheel that matches your setup. If you are using different hardware (AMD GPU, TPU, etc.), you can modify requirements.txt accordingly. Alternatively, build from source by modifying the Dockerfile, following the JAX installation instructions. If you face installation issues, please don't hesitate to raise an issue on GitHub and we will add support for your hardware.
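After installing the wheel that matches your hardware, you can confirm that JAX actually sees your accelerator. The small helper below only uses jax.default_backend() and jax.devices(). On a correctly configured NVIDIA setup it should report a gpu backend, whereas a cpu backend usually points to a mismatched CUDA wheel or driver.

```python
import importlib.util


def describe_jax_backend() -> str:
    """Report which backend and devices JAX will use, or that JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed in this environment"
    import jax  # imported lazily so the check also runs before installation

    devices = ", ".join(str(d) for d in jax.devices())
    return f"backend={jax.default_backend()} devices=[{devices}]"


print(describe_jax_backend())
```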

1. Virtual Environment

Create a virtual environment by running the following command


This creates the env_jax_dips virtual environment. You can then enter it with source env_jax_dips/bin/activate and leave it with deactivate when you are done.

2. Docker

Docker images provide an isolated and consistent runtime environment, ensuring that the application behaves the same regardless of the host system. We recommend the docker image provided here. It comes fully loaded with libraries for datacenter-scale simulations and is optimized for NVIDIA GPUs. For a full list of the supported software and the specific versions packaged with this container image, see the Frameworks Support Matrix.


First, you will need to install the NVIDIA driver and the Docker engine.

Container Settings

To personalize your development environment, set up a .env file that contains the following variables. The docker image entry defaults to the image available for download. Change it if you want to build new docker images and push them to your preferred docker registry.
DATA_PATH=/data/                                   # default data path inside your docker container, will mirror your DATA_MOUNT_PATH directory
DATA_MOUNT_PATH=/data                              # default data mount path inside your machine, will mirror into your docker container's DATA_PATH directory
RESULT_PATH=/results/                              # default result path inside your docker container
RESULT_MOUNT_PATH=/results/                        # default result mount path inside your machine
REGISTRY=<your-preferred-registry-name>            # (optional) name of your preferred docker registry
REGISTRY_USER=<your-registry-username>             # (optional) your username to connect to docker registry
REGISTRY_ACCESS_TOKEN=<your-registry-access-token> # (optional) your access token to connect to docker registry
WANDB_API_KEY=NotSpecified                         # (optional) your API key to connect to Weights and Biases service
JUPYTER_PORT=8888                                  # (optional) port to connect to jupyter server
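The comments above describe how each *_PATH inside the container mirrors a *_MOUNT_PATH on the host. The following is a hypothetical helper, not the project's launch script. It reads such a file and turns those pairs into docker run -v host:container volume flags. It assumes inline comments are introduced by whitespace followed by "#".

```python
import re
from typing import Dict, List


def parse_env(text: str) -> Dict[str, str]:
    """Parse KEY=VALUE lines, skipping blanks and full-line comments and dropping ' # ...' trailers."""
    env: Dict[str, str] = {}
    for raw in text.splitlines():
        line = re.split(r"\s+#", raw, maxsplit=1)[0].strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        env[key.strip()] = value.strip()
    return env


def volume_flags(env: Dict[str, str]) -> List[str]:
    """docker run -v flags mapping each host *_MOUNT_PATH onto the container *_PATH."""
    flags: List[str] = []
    for prefix in ("DATA", "RESULT"):
        host = env.get(f"{prefix}_MOUNT_PATH")
        container = env.get(f"{prefix}_PATH")
        if host and container:
            flags += ["-v", f"{host}:{container}"]
    return flags


sample = """
DATA_PATH=/data/            # inside the container
DATA_MOUNT_PATH=/data       # on the host
RESULT_PATH=/results/
RESULT_MOUNT_PATH=/results/
WANDB_API_KEY=NotSpecified
"""
env = parse_env(sample)
print(volume_flags(env))  # ['-v', '/data:/data/', '-v', '/results/:/results/']
```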

Pull development container

The latest docker image is available on Docker Hub as pourion/jax_dips:latest. Instead of building the container, you can simply pull this image by running

./ pull

This pulls from Docker Hub and is equivalent to running docker pull pourion/jax_dips:latest.

Build development container

Alternatively, you can build the container by running the following command:

./ build

This is the recommended approach if you want to add additional libraries to your container.

Start development container

The following command creates a container and places you inside it with the source code mounted:

./ dev

You can also run the container in the background, without your terminal jumping into it, by passing the -d (daemon) flag:

./ dev -d

You can attach your terminal to the running jax_dips container at any time with

./ attach

Development in VS Code

Once the container is running on your machine, you can attach to it from VS Code. Install Microsoft's Dev Containers extension in VS Code and press Ctrl+Shift+P. Choose Dev Containers: Attach to Running Container..., then select the jax_dips container from the list of running containers on your machine.


If you use JAX-DIPS in your research, please cite the following:

  • Mistani, Pouria, Samira Pakravan, Rajesh Ilango, and Frederic Gibou. "JAX-DIPS: Neural bootstrapping of finite discretization methods and application to elliptic problems with discontinuities." arXiv preprint arXiv:2210.14312.
  • Mistani, Pouria, Samira Pakravan, Rajesh Ilango, Sanjay Choudhry, and Frederic Gibou. "Neuro-symbolic partial differential equation solver." arXiv preprint arXiv:2210.14907.

Contributing to JAX-DIPS

  • Reporting bugs. To report a bug please open an issue in the GitHub Issues.
  • Suggesting enhancements. To submit an enhancement suggestion, including new features or improvements to existing functionality, let us know by opening an issue in the GitHub Issues.
  • Pull requests. If you have improved JAX-DIPS, fixed a bug, or added a new example, feel free to send us a pull request.

The Team

JAX-DIPS was developed by Pouria Mistani, Samira Pakravan, and Rajesh Ilango under the supervision of Prof. Frederic Gibou during 2019-2022 at the University of California, Santa Barbara. This project was partially funded by the US Office of Naval Research.


LGPL-2.1 License