JaxNumPy functions for GridArrays/GridVariables #110

Open
gemmaellen opened this issue Nov 21, 2021 · 2 comments

@gemmaellen

gemmaellen commented Nov 21, 2021

Hi! This is a great project, and I'm a big fan of both the machine learning applications here and some of the smaller helper structures, in particular `base.grids`.

Currently, it is possible to add two GridArrays, but it is not possible to add two GridVariables. So this works fine:

```python
import jax_cfd.base.grids as gd
import jax.numpy as jnp

grid = gd.Grid([4], domain=[(0, 1)])

array_of_values = jnp.array([2.0, 2.0, 3.0, 4.0])

centered_array = grid.center(array_of_values)

print(centered_array + centered_array)
```

But this throws an exception:

```python
bc = gd.BoundaryConditions((gd.PERIODIC,))

centered_variable = gd.GridVariable(centered_array, bc)

print(centered_variable + centered_variable)
```

I'm happy to have a go at implementing this myself, if someone isn't already working on it.

Also, am I correct in thinking that the way to use a JaxNumPy function on a GridArray is to call it via NumPy? For example, this throws an exception:

```python
print(jnp.abs(centered_array))
```

But this works:

```python
import numpy as np
print(np.abs(centered_array))
```

I assume it's implemented this way because NumPy provides a mix-in (`np.lib.mixins.NDArrayOperatorsMixin`, together with the `__array_ufunc__` protocol) that we can use to funnel calls to the appropriate JaxNumPy function, whereas JaxNumPy has no equivalent.

@shoyer
Member

shoyer commented Dec 2, 2021

Hi @gemmaellen -- thanks for your interest!

This was an intentional design choice: GridVariables carry boundary conditions, and we don't know how to propagate those automatically through arbitrary operations (except for periodic boundaries, which aren't really boundary conditions at all). So we only support math on GridArray objects.
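To illustrate the design choice, here is a hypothetical sketch (not jax_cfd code). It uses a toy `ToyGridVariable` that stores values plus a tuple of per-axis boundary-condition names, and a hypothetical `add_variables` helper. The helper only returns a new variable when both inputs are periodic, since a periodic field plus a periodic field is still periodic; for any other boundary type it refuses rather than guessing what the result's boundary condition should be.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ToyGridVariable:
    """Hypothetical stand-in: sampled values plus per-axis BC type names."""
    data: np.ndarray
    bc_types: tuple  # e.g. ("periodic",) or ("dirichlet",)


def add_variables(u, v):
    """Hypothetical helper: add two variables only when the BC trivially carries over.

    With periodic boundaries the sum is still periodic, so the BC can be
    reused as-is. For other BCs the result's boundary behavior depends on
    the boundary values themselves, so we refuse instead of guessing.
    """
    if u.bc_types != v.bc_types:
        raise ValueError("boundary conditions differ")
    if any(t != "periodic" for t in u.bc_types):
        raise NotImplementedError("don't know how to propagate non-periodic BCs")
    return ToyGridVariable(u.data + v.data, u.bc_types)


u = ToyGridVariable(np.array([2.0, 2.0, 3.0, 4.0]), ("periodic",))
print(add_variables(u, u))
```

With the real classes, the equivalent workaround is presumably to do the arithmetic on the underlying GridArray values and wrap the result in a new GridVariable yourself, choosing a boundary condition you know is correct for the result.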

> Also, am I correct in thinking that the way to use a JaxNumPy function on a GridArray is to call it via NumPy?

This is correct, and I agree it's strange. The simple reason is that NumPy supports overriding its functions for new types, but JAX doesn't.
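A minimal sketch of the mechanism being described, using a toy class rather than jax_cfd's actual `GridArray`. `np.lib.mixins.NDArrayOperatorsMixin` maps operators such as `+` onto NumPy ufuncs, and `__array_ufunc__` intercepts those ufuncs (including direct calls like `np.abs(x)`), unwraps the data, applies the function and re-wraps the result. `ToyGridArray` and its `offset` consistency check are hypothetical; to keep the sketch dependency-free it forwards to the NumPy ufunc itself, where jax_cfd presumably looks up the `jax.numpy` function of the same name.

```python
import numbers

import numpy as np


class ToyGridArray(np.lib.mixins.NDArrayOperatorsMixin):
    """Hypothetical stand-in for a GridArray: data plus a grid offset."""

    # Non-wrapper types we are willing to combine with.
    _HANDLED_TYPES = (numbers.Number, np.ndarray)

    def __init__(self, data, offset):
        self.data = np.asarray(data)
        self.offset = tuple(offset)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Only support plain calls like np.abs(x) or x + y (no reduce, out=, ...).
        if method != "__call__" or kwargs:
            return NotImplemented
        for x in inputs:
            if not isinstance(x, self._HANDLED_TYPES + (ToyGridArray,)):
                return NotImplemented
        # All wrapped inputs must live at the same offset on the grid.
        offsets = {x.offset for x in inputs if isinstance(x, ToyGridArray)}
        if len(offsets) != 1:
            raise ValueError(f"inconsistent offsets: {offsets}")
        # Unwrap, apply the ufunc to the raw arrays, and re-wrap the result.
        raw = [x.data if isinstance(x, ToyGridArray) else x for x in inputs]
        return ToyGridArray(ufunc(*raw), offsets.pop())


a = ToyGridArray([-2.0, 2.0, 3.0], offset=(0.5,))
print((a + a).data)     # `+` goes through the mixin into __array_ufunc__
print(np.abs(a).data)   # NumPy ufuncs honor the override directly
```

A call like `jnp.abs(a)` never reaches `__array_ufunc__`, because, as far as I know, jax.numpy functions don't consult NumPy's override protocols and instead reject non-array inputs, which matches the exception reported above.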

@gemmaellen
Author

Oh, I see! Yes, addition and multiplication would also work for matching homogeneous Dirichlet/Neumann BCs, but that's a special case, and it wouldn't extend to other functions such as sines, cosines and so on. I assume this is also why the `shift` method on a GridVariable returns a GridArray. Thanks for the explanation!
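A quick numerical check of the special case, assuming "homogeneous Dirichlet" means the field is zero at both ends of [0, 1]: sums and products of such fields stay zero at the boundary, while a pointwise function like cosine does not, since cos(0) = 1. The sample field `sin(pi x)` and the helper `is_homogeneous_dirichlet` are illustrative choices of mine, and the tolerance absorbs round-off (`sin(pi)` is not exactly 0 in floating point).

```python
import numpy as np


def is_homogeneous_dirichlet(g, atol=1e-12):
    """True if the sampled field is (numerically) zero at both endpoints."""
    return bool(np.isclose(g[0], 0.0, atol=atol) and np.isclose(g[-1], 0.0, atol=atol))


# Sample f(x) = sin(pi x) on [0, 1], including both boundary points.
x = np.linspace(0.0, 1.0, 11)
f = np.sin(np.pi * x)

checks = {
    "f": is_homogeneous_dirichlet(f),
    "f + f": is_homogeneous_dirichlet(f + f),
    "f * f": is_homogeneous_dirichlet(f * f),
    "cos(f)": is_homogeneous_dirichlet(np.cos(f)),  # cos(0) = 1 breaks the BC
}
for name, ok in checks.items():
    print(f"{name:7s} zero at boundary: {ok}")
```

The homogeneous Neumann case works out similarly from (f + g)' = f' + g' and (fg)' = f'g + fg', both of which vanish at the boundary when f' and g' do.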
