Question about the random initial velocity in notebook spectral_forced_turbulence.ipynb #99

Open
sifanexisted opened this issue Oct 29, 2021 · 2 comments

@sifanexisted

Hello,

This is a very nice JAX implementation for CFD! I have a few questions about your recently released notebook spectral_forced_turbulence.ipynb.

  1. What is the exact NS equation that the notebook aims to solve? Would it be possible to provide detailed expressions of the NS equations and boundary conditions, or some references? I think it would be super helpful to people like me who are more familiar with machine learning and know less about CFD.

  2. If I understand correctly, I think the code aims to solve the NS equations (vorticity-velocity form) using the pseudo-spectral method, so we need to enforce periodic boundary conditions for the velocity (or vorticity?). However, when I run the following piece of code (which should generate a random velocity field and plot its values at the boundaries),

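```python
import jax
import matplotlib.pyplot as plt
import jax_cfd.base as cfd

# `grid` and `max_velocity` are defined as in the notebook
v0 = cfd.initial_conditions.filtered_velocity_field(jax.random.PRNGKey(0), grid, max_velocity, 4)
u = v0[0].data
v = v0[1].data

# compare the first and last columns of u
plt.plot(u[:, 0])
plt.plot(u[:, -1])
plt.show()

# compare the first and last rows of u
plt.plot(u[0, :])
plt.plot(u[-1, :])
plt.show()
```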

I get the figures below. It seems that the periodic boundary condition is not fully imposed. Could you please take a look, or correct me if I am wrong?

[Figures: bc, bc2]

Once again, this is really a great project! Looking forward to your reply.

@shoyer
Member

shoyer commented Oct 29, 2021

The spectral method code currently always imposes periodic boundary conditions. You can find some references to the particular formulation used in the source code, and eventually we'll probably write it up in a more detailed way in a paper.
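
For readers coming from machine learning, one standard way to write the 2D incompressible Navier-Stokes equations that a pseudo-spectral vorticity solver integrates on the doubly periodic domain [0, 2π)² is sketched below. Here ω is the vorticity, ψ the stream function, u the velocity, ν the viscosity, α a linear drag coefficient, and f a body force; for Kolmogorov flow f is typically (sin(ky), 0), whose curl is -k cos(ky). Treat this as a general template rather than the exact equations in the code: whether `spectral.equations.ForcedNavierStokes2D` includes a drag term, and which forcing wavenumber and amplitude it uses, are assumptions to check against the source.

```latex
\begin{aligned}
\partial_t \omega + \mathbf{u}\cdot\nabla\omega &= \nu\,\nabla^2\omega - \alpha\,\omega + \nabla\times\mathbf{f}, \\
\mathbf{u} = \left(\partial_y \psi,\; -\partial_x \psi\right), &\qquad -\nabla^2\psi = \omega, \\
\omega(x + 2\pi, y) = \omega(x, y + 2\pi) &= \omega(x, y).
\end{aligned}
```

With this sign convention, ω = ∂ₓv - ∂ᵧu = -∇²ψ, so the velocity can be recovered from the vorticity by solving a Poisson equation, which is diagonal in Fourier space on a periodic domain.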

To answer your second question, the solution is not exactly the same at the first and last array values, because they are off by one grid point, i.e., we keep track of the values u[0], u[1], ..., u[n - 1]. The value u[n] would be the same as u[0], but we don't use it in the state because that representation would be redundant.
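
The point about the omitted value u[n] can be checked numerically. The sketch below uses plain NumPy and a made-up 1D test signal (not jax_cfd's initial condition): it stores n samples at x_i = i·dx on [0, L), builds the trigonometric (Fourier) interpolant that a pseudo-spectral method implicitly assumes, and evaluates it at x = L. The result is the stored u[0], so nothing is lost by dropping the endpoint. The helper `trig_interp` is hypothetical, written only for this illustration.

```python
import numpy as np

# Periodic grid on [0, L): store n samples at x_i = i * dx, i = 0..n-1.
# The point x_n = L is omitted because u(L) == u(0) by periodicity.
L = 2 * np.pi
n = 16
dx = L / n
x = np.arange(n) * dx

# A smooth periodic test signal (stands in for one row of a velocity field).
u = np.sin(x) + 0.5 * np.cos(3 * x)


def trig_interp(u, L, x_eval):
    """Evaluate the Fourier interpolant of periodic samples u at the points x_eval."""
    n = u.size
    coeffs = np.fft.fft(u) / n
    freqs = np.fft.fftfreq(n, d=L / n)  # cycles per unit length
    phases = np.exp(2j * np.pi * np.outer(np.atleast_1d(x_eval), freqs))
    return (phases @ coeffs).real


# The "missing" value u[n] at x = L equals the stored u[0].
u_at_L = trig_interp(u, L, L)[0]
print(u_at_L, u[0])  # both approximately 0.5
```

Because the interpolant is built from exactly one period of samples, evaluating it anywhere in [0, L] (including between grid points) reproduces this band-limited signal exactly.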

@sifanexisted
Author

sifanexisted commented Oct 29, 2021

Thanks for your timely reply!

I think I now understand what you mean. Basically, the initial state

```python
v0 = cfd.initial_conditions.filtered_velocity_field(jax.random.PRNGKey(0), grid, max_velocity, 4)
```

only stores the boundary values on one side of each axis, not both.
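
If the goal is just to plot matching boundaries, one option is to append the first row and column to the end of the array before plotting. The NumPy sketch below does this with `np.pad(..., mode="wrap")` on random stand-in data; applying it to `np.asarray(v0[0].data)` is an assumption about how you might use it, not something jax_cfd does for you.

```python
import numpy as np


def close_periodic(field):
    """Append the first row and column to the end of a 2D periodic array.

    The result has shape (n + 1, m + 1): the extra row/column are copies of
    the first row/column, i.e. the values at x = L (or y = L) by periodicity.
    """
    return np.pad(field, ((0, 1), (0, 1)), mode="wrap")


rng = np.random.default_rng(0)
u = rng.standard_normal((4, 5))  # stands in for v0[0].data
u_closed = close_periodic(u)
print(u_closed.shape)  # (5, 6)
```

With the closed array, plotting u_closed[:, 0] and u_closed[:, -1] draws identical curves, since the last column is a copy of the first.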

But I have one additional question. The following snippet returns initial velocity components u and v, each with shape (256, 256).

```python
import jax
import jax.numpy as jnp
import jax_cfd.base as cfd
import jax_cfd.base.grids as grids
import jax_cfd.spectral as spectral

# physical parameters
viscosity = 1e-3
max_velocity = 7
grid = grids.Grid((256, 256), domain=((0, 2 * jnp.pi), (0, 2 * jnp.pi)))
dt = cfd.equations.stable_time_step(max_velocity, .5, viscosity, grid)

# set up the step function using Crank-Nicolson Runge-Kutta order 4
smooth = True  # use anti-aliasing

# use predefined settings for Kolmogorov flow
step_fn = spectral.time_stepping.crank_nicolson_rk4(
    spectral.equations.ForcedNavierStokes2D(viscosity, grid, smooth=smooth), dt)

# run the simulation up until time 25.0 but only save 10 frames for visualization
final_time = 25.0
outer_steps = 10
inner_steps = int(final_time // dt) // outer_steps  # integer number of steps per saved frame

# create an initial velocity field and compute the FFT of the vorticity;
# the spectral code assumes an FFT'd vorticity as its initial state
v0 = cfd.initial_conditions.filtered_velocity_field(jax.random.PRNGKey(42), grid, max_velocity, 4)
vorticity0 = cfd.finite_differences.curl_2d(v0).data
vorticity_hat0 = jnp.fft.rfftn(vorticity0)

u = v0[0].data
v = v0[1].data

print(u.shape)  # (256, 256)
print(v.shape)
```
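
The comment about the FFT'd vorticity is easier to follow once you see the array shapes. For a real n × n field, `rfftn` keeps only the non-negative frequencies along the last axis, so `vorticity_hat0` should have shape (256, 129) rather than (256, 256). The sketch below uses NumPy with random stand-in data; `jnp.fft.rfftn` and `jnp.fft.irfftn` are assumed to follow the same shape conventions as their NumPy counterparts.

```python
import numpy as np

# Real-input FFT of a 256 x 256 real field (stand-in for vorticity0).
n = 256
rng = np.random.default_rng(0)
vorticity = rng.standard_normal((n, n))

vorticity_hat = np.fft.rfftn(vorticity)
print(vorticity_hat.shape)  # (256, 129): last axis keeps n // 2 + 1 frequencies

# Passing s= guarantees the inverse returns the original shape
# (it matters when the last axis has odd length).
recovered = np.fft.irfftn(vorticity_hat, s=vorticity.shape)
print(np.allclose(recovered, vorticity))  # True
```

Because the stored grid has no duplicated endpoint, the discrete Fourier transform sees exactly one period of the field, which is what makes this round trip exact.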

As you mentioned, u and v should only contain the values at two of the four boundaries (e.g. left and bottom). That seems to imply that u, v, and vorticity0 are not evaluated on the uniform grid

```python
grid = grids.Grid((256, 256), domain=((0, 2 * jnp.pi), (0, 2 * jnp.pi)))
```

So I am wondering which mesh is actually used to evaluate variables such as velocity and vorticity. Should it be something like

```python
grid = grids.Grid((256, 256), domain=((0, 2 * jnp.pi - dx), (0, 2 * jnp.pi - dy)))
```

where dx and dy are the mesh sizes?

Looking forward to your reply.
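
On the mesh question, some arithmetic may help. For a uniform periodic grid of n cells on [0, L), the spacing is already dx = L / n, and the stored samples stop one spacing (node-based) or half a spacing (cell-centered) short of L. The NumPy sketch below compares the two offsets and shows that shrinking the domain to [0, L - dx) would change both the spacing and the period. That `grids.Grid` computes its step as (upper - lower) / n and places values at (i + offset) * step is an assumption here; the `.offset` attribute of each component of `v0` should indicate which offset a given variable uses.

```python
import numpy as np

L = 2 * np.pi
n = 256
dx = L / n  # spacing of a periodic grid with n cells on [0, L)

# Sample locations for two common offsets (assumed convention: x_i = (i + offset) * dx).
nodes = np.arange(n) * dx            # offset 0.0: 0, dx, ..., L - dx
centers = (np.arange(n) + 0.5) * dx  # offset 0.5: dx/2, ..., L - dx/2

print(np.isclose(nodes[-1], L - dx))        # True: the last node stops short of L
print(np.isclose(centers[-1], L - dx / 2))  # True

# Shrinking the domain changes the spacing, and therefore the period, of the field.
dx_shrunk = (L - dx) / n
print(dx_shrunk < dx)  # True
```

Under this convention, keeping domain=((0, 2π), (0, 2π)) already yields samples that never include 2π, while the shrunk domain would describe a field with period 2π - dx instead of 2π.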
