
Using flax.nn.scan with modules that have callable containing kwargs #3764


My solution below may be suboptimal. If others have encountered the same issue, any pointers or discussion would be appreciated:

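from typing import Optional

import flax.linen as nn
from jax.typing import ArrayLike  # assumed source of ArrayLike; the original post does not show its imports
from omegaconf import DictConfig  # assumed source of DictConfig


class Encoder1DBlock(nn.Module):
    """Transformer encoder layer."""

    # Sub-module configurations.
    layer_norm: DictConfig
    dropout: DictConfig
    self_attention: DictConfig
    mlp_block: DictConfig
    # Values that would otherwise be passed as call-time kwargs.
    train: Optional[bool] = None
    mask: Optional[ArrayLike] = None

    @nn.compact
    def __call__(self, inputs, mask=None, train=None):
        # Resolve each value from either the module attribute or the call argument.
        train = nn.merge_param('train', self.train, train)
        mask = nn.merge_param('mask', 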
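
For readers unfamiliar with `nn.merge_param`, here is a rough plain-Python sketch of how I understand its behavior: exactly one of the module attribute and the call argument may be set, and that one is returned. The name `merge_param_sketch` is hypothetical and this is not Flax's implementation, only an illustration of why `train` and `mask` can be supplied through the constructor instead of as call-time kwargs.

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def merge_param_sketch(name: str, attr: Optional[T], arg: Optional[T]) -> T:
    """Hypothetical stand-in for nn.merge_param: exactly one value may be set."""
    if attr is None and arg is None:
        # Neither the constructor nor the call supplied a value.
        raise ValueError(f'Parameter "{name}" must be set in the constructor or at call time.')
    if attr is not None and arg is not None:
        # Both were supplied, so it is ambiguous which one should win.
        raise ValueError(f'Parameter "{name}" was set in the constructor and at call time.')
    # Compare against None rather than truthiness so that False is a valid value.
    return arg if attr is None else attr


# Value supplied through the constructor (module attribute), nothing at call time.
train_from_attr = merge_param_sketch("train", True, None)
# Value supplied at call time only.
train_from_call = merge_param_sketch("train", None, False)
```

With this behavior, a wrapped module can receive `train` and `mask` as attributes when it is constructed, so its `__call__` only needs positional inputs.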

Replies: 1 comment

0 replies
Answer selected by peterdavidfagan
Category
Q&A
Labels
Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment.
1 participant