
How to init BatchNorm in the setup function #3881

Closed. Answered by chiamp.
davidshen84 asked this question in Q&A

You can leave use_running_average unset when constructing BatchNorm in setup and pass it at call time instead:

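import flax.linen as nn
import jax, jax.numpy as jnp

class Model(nn.Module):
  def setup(self):
    self.dense = nn.Dense(3)
    # use_running_average is left unset so it can be chosen on every call.
    self.batchnorm = nn.BatchNorm()

  def __call__(self, x, train):
    # train=True: normalize with batch statistics and update the running averages.
    # train=False: normalize with the stored running averages.
    return self.dense(self.batchnorm(x, use_running_average=not train))

x = jnp.ones((1, 2))
model = Model()
variables = model.init(jax.random.key(0), x, train=False)
# batch_stats must be mutable when train=True; apply then returns (output, updated_state),
# where updated_state is a dict of the form {'batch_stats': ...}.
out, updates = model.apply(variables, x, train=True, mutable=['batch_stats'])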

You could also pass use_running_average when constructing BatchNorm in setup, but that locks the value in for every call; passing it again at call time then raises an error:

class Model(nn.Module):
  def setup(self):
    self.dense = nn.Dense(3)
    self.batchnorm = nn.BatchNorm(use_ru…

Answer selected by davidshen84